33#include < time.h>
44#include " Common.h"
55#include " Windows.AI.Machinelearning.Native.h"
6-
6+ # include " d3dx12.h "
77using namespace winrt ::Windows::Media;
88using namespace winrt ::Windows::Storage;
99using namespace winrt ::Windows::AI::MachineLearning;
@@ -327,7 +327,7 @@ namespace BindingUtilities
327327
328328 template <TensorKind T>
329329 static ITensor CreateTensor (const CommandLineArgs& args, const std::vector<std::string>& tensorStringInput,
330- const IVectorView<int64_t >& tensorShape)
330+ const IVectorView<int64_t >& tensorShape, const InputBindingType inputBindingType )
331331 {
332332 using TensorValue = typename TensorKindToValue<T>::Type;
333333 using DataType = typename TensorKindToType<T>::Type;
@@ -372,11 +372,106 @@ namespace BindingUtilities
372372 // Creating Tensors for Input Images haven't been added yet.
373373 throw hresult_not_implemented (L" Creating Tensors for Input Images haven't been implemented yet!" );
374374 }
375- return tensorValue;
375+
376+ if (inputBindingType == InputBindingType::CPU)
377+ {
378+ return tensorValue;
379+ }
380+ else // GPU Tensor
381+ {
382+ com_ptr<ID3D12Resource> pGPUResource = nullptr ;
383+ try
384+ {
385+ // create the d3d device.
386+ com_ptr<ID3D12Device> pD3D12Device = nullptr ;
387+ D3D12CreateDevice (nullptr , D3D_FEATURE_LEVEL::D3D_FEATURE_LEVEL_11_0, __uuidof (ID3D12Device),
388+ reinterpret_cast <void **>(&pD3D12Device));
389+
390+ pD3D12Device->CreateCommittedResource (
391+ &CD3DX12_HEAP_PROPERTIES (D3D12_HEAP_TYPE_DEFAULT),
392+ D3D12_HEAP_FLAG_NONE,
393+ &CD3DX12_RESOURCE_DESC::Buffer (
394+ actualSizeInBytes,
395+ D3D12_RESOURCE_FLAG_ALLOW_UNORDERED_ACCESS),
396+ D3D12_RESOURCE_STATE_COMMON, nullptr ,
397+ __uuidof (ID3D12Resource), pGPUResource.put_void ());
398+ if (!args.IsGarbageInput ())
399+ {
400+ com_ptr<ID3D12Resource> imageUploadHeap;
401+ // Create the GPU upload buffer.
402+ pD3D12Device->CreateCommittedResource (
403+ &CD3DX12_HEAP_PROPERTIES (D3D12_HEAP_TYPE_UPLOAD), D3D12_HEAP_FLAG_NONE,
404+ &CD3DX12_RESOURCE_DESC::Buffer (actualSizeInBytes), D3D12_RESOURCE_STATE_GENERIC_READ, nullptr ,
405+ __uuidof (ID3D12Resource), imageUploadHeap.put_void ());
406+
407+ // create the command queue.
408+ com_ptr<ID3D12CommandQueue> dxQueue = nullptr ;
409+ D3D12_COMMAND_QUEUE_DESC commandQueueDesc = {};
410+ commandQueueDesc.Type = D3D12_COMMAND_LIST_TYPE_DIRECT;
411+ pD3D12Device->CreateCommandQueue (&commandQueueDesc, __uuidof (ID3D12CommandQueue),
412+ reinterpret_cast <void **>(&dxQueue));
413+ com_ptr<ILearningModelDeviceFactoryNative> devicefactory =
414+ get_activation_factory<LearningModelDevice, ILearningModelDeviceFactoryNative>();
415+ com_ptr<::IUnknown> spUnk;
416+ devicefactory->CreateFromD3D12CommandQueue (dxQueue.get (), spUnk.put ());
417+
418+ // Create ID3D12GraphicsCommandList and Allocator
419+ D3D12_COMMAND_LIST_TYPE queuetype = dxQueue->GetDesc ().Type ;
420+ com_ptr<ID3D12CommandAllocator> alloctor;
421+ com_ptr<ID3D12GraphicsCommandList> cmdList;
422+ pD3D12Device->CreateCommandAllocator (queuetype, winrt::guid_of<ID3D12CommandAllocator>(),
423+ alloctor.put_void ());
424+ pD3D12Device->CreateCommandList (0 , queuetype, alloctor.get (), nullptr ,
425+ winrt::guid_of<ID3D12CommandList>(), cmdList.put_void ());
426+
427+ // Copy from Cpu to GPU
428+ D3D12_SUBRESOURCE_DATA CPUData = {};
429+ CPUData.pData = actualData;
430+ CPUData.RowPitch = actualSizeInBytes;
431+ CPUData.SlicePitch = actualSizeInBytes;
432+ UpdateSubresources (cmdList.get (), pGPUResource.get (), imageUploadHeap.get (), 0 , 0 , 1 , &CPUData);
433+
434+ // Close the command list and execute it to begin the initial GPU setup.
435+ cmdList->Close ();
436+ ID3D12CommandList* ppCommandLists[] = { cmdList.get () };
437+ dxQueue->ExecuteCommandLists (_countof (ppCommandLists), ppCommandLists);
438+
439+ // Create Event
440+ HANDLE directEvent = CreateEvent (nullptr , FALSE , FALSE , nullptr );
441+
442+ // Create Fence
443+ Microsoft::WRL::ComPtr<ID3D12Fence> spDirectFence = nullptr ;
444+ THROW_IF_FAILED (pD3D12Device->CreateFence (0 , D3D12_FENCE_FLAG_NONE,
445+ IID_PPV_ARGS (spDirectFence.ReleaseAndGetAddressOf ())));
446+ // Adds fence to queue
447+ THROW_IF_FAILED (dxQueue->Signal (spDirectFence.Get (), 1 ));
448+ THROW_IF_FAILED (spDirectFence->SetEventOnCompletion (1 , directEvent));
449+
450+ // Wait for signal
451+ DWORD retVal = WaitForSingleObject (directEvent, INFINITE);
452+ if (retVal != WAIT_OBJECT_0)
453+ {
454+ THROW_IF_FAILED (E_UNEXPECTED);
455+ }
456+ }
457+ }
458+ catch (...)
459+ {
460+ std::cout << " Couldn't create and copy CPU tensor resource to GPU resource" << std::endl;
461+ throw ;
462+ }
463+ com_ptr<ITensorStaticsNative> tensorfactory = get_activation_factory<TensorValue, ITensorStaticsNative>();
464+ com_ptr<::IUnknown> spUnkTensor;
465+ tensorfactory->CreateFromD3D12Resource (pGPUResource.get (), vecShape.data (), static_cast <int >(vecShape.size ()), spUnkTensor.put ());
466+ TensorValue returnTensor (nullptr );
467+ spUnkTensor.try_as (returnTensor);
468+ return returnTensor;
469+ }
376470 }
377471
378472 // Binds tensor floats, ints, doubles from CSV data.
379- ITensor CreateBindableTensor (const ILearningModelFeatureDescriptor& description, const CommandLineArgs& args)
473+ ITensor CreateBindableTensor (const ILearningModelFeatureDescriptor& description, const CommandLineArgs& args,
474+ const InputBindingType inputBindingType)
380475 {
381476 std::vector<std::string> elementStrings;
382477 if (!args.CsvPath ().empty ())
@@ -407,7 +502,7 @@ namespace BindingUtilities
407502 std::vector<int64_t > shape = { 1 , channels, imageFeatureDescriptor.Height (),
408503 imageFeatureDescriptor.Width () };
409504 IVectorView<int64_t > shapeVectorView = single_threaded_vector (std::move (shape)).GetView ();
410- return CreateTensor<TensorKind::Float>(args, elementStrings, shapeVectorView);
505+ return CreateTensor<TensorKind::Float>(args, elementStrings, shapeVectorView, inputBindingType );
411506 }
412507
413508 auto tensorDescriptor = description.try_as <TensorFeatureDescriptor>();
@@ -422,57 +517,68 @@ namespace BindingUtilities
422517 }
423518 case TensorKind::Float:
424519 {
425- return CreateTensor<TensorKind::Float>(args, elementStrings, tensorDescriptor.Shape ());
520+ return CreateTensor<TensorKind::Float>(args, elementStrings, tensorDescriptor.Shape (),
521+ inputBindingType);
426522 }
427523 break ;
428524 case TensorKind::Float16:
429525 {
430- return CreateTensor<TensorKind::Float16>(args, elementStrings, tensorDescriptor.Shape ());
526+ return CreateTensor<TensorKind::Float16>(args, elementStrings, tensorDescriptor.Shape (),
527+ inputBindingType);
431528 }
432529 break ;
433530 case TensorKind::Double:
434531 {
435- return CreateTensor<TensorKind::Double>(args, elementStrings, tensorDescriptor.Shape ());
532+ return CreateTensor<TensorKind::Double>(args, elementStrings, tensorDescriptor.Shape (),
533+ inputBindingType);
436534 }
437535 break ;
438536 case TensorKind::Int8:
439537 {
440- return CreateTensor<TensorKind::Int8>(args, elementStrings, tensorDescriptor.Shape ());
538+ return CreateTensor<TensorKind::Int8>(args, elementStrings, tensorDescriptor.Shape (),
539+ inputBindingType);
441540 }
442541 break ;
443542 case TensorKind::UInt8:
444543 {
445- return CreateTensor<TensorKind::UInt8>(args, elementStrings, tensorDescriptor.Shape ());
544+ return CreateTensor<TensorKind::UInt8>(args, elementStrings, tensorDescriptor.Shape (),
545+ inputBindingType);
446546 }
447547 break ;
448548 case TensorKind::Int16:
449549 {
450- return CreateTensor<TensorKind::Int16>(args, elementStrings, tensorDescriptor.Shape ());
550+ return CreateTensor<TensorKind::Int16>(args, elementStrings, tensorDescriptor.Shape (),
551+ inputBindingType);
451552 }
452553 break ;
453554 case TensorKind::UInt16:
454555 {
455- return CreateTensor<TensorKind::UInt16>(args, elementStrings, tensorDescriptor.Shape ());
556+ return CreateTensor<TensorKind::UInt16>(args, elementStrings, tensorDescriptor.Shape (),
557+ inputBindingType);
456558 }
457559 break ;
458560 case TensorKind::Int32:
459561 {
460- return CreateTensor<TensorKind::Int32>(args, elementStrings, tensorDescriptor.Shape ());
562+ return CreateTensor<TensorKind::Int32>(args, elementStrings, tensorDescriptor.Shape (),
563+ inputBindingType);
461564 }
462565 break ;
463566 case TensorKind::UInt32:
464567 {
465- return CreateTensor<TensorKind::UInt32>(args, elementStrings, tensorDescriptor.Shape ());
568+ return CreateTensor<TensorKind::UInt32>(args, elementStrings, tensorDescriptor.Shape (),
569+ inputBindingType);
466570 }
467571 break ;
468572 case TensorKind::Int64:
469573 {
470- return CreateTensor<TensorKind::Int64>(args, elementStrings, tensorDescriptor.Shape ());
574+ return CreateTensor<TensorKind::Int64>(args, elementStrings, tensorDescriptor.Shape (),
575+ inputBindingType);
471576 }
472577 break ;
473578 case TensorKind::UInt64:
474579 {
475- return CreateTensor<TensorKind::UInt64>(args, elementStrings, tensorDescriptor.Shape ());
580+ return CreateTensor<TensorKind::UInt64>(args, elementStrings, tensorDescriptor.Shape (),
581+ inputBindingType);
476582 }
477583 break ;
478584 }
0 commit comments