Skip to content

Commit 87e8a5d

Browse files
Implement DML copy for Lora Adapters (#22396)
### Description Request and create DML EP and its data transfer. Use to copy on device. The PR includes changes to fix issues in DML provider. ### Motivation and Context This enables Lora users to run it with DML which is important for GenAI. Co-authored-by: @PatriceVignola --------- Co-authored-by: Patrice Vignola <[email protected]>
1 parent 35adba2 commit 87e8a5d

File tree

11 files changed

+136
-78
lines changed

11 files changed

+136
-78
lines changed

onnxruntime/core/providers/dml/DmlExecutionProvider/src/BucketizedBufferAllocator.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,7 @@ namespace Dml
186186
}
187187
else
188188
{
189-
if (!m_context->IsClosed())
189+
if (!m_closed)
190190
{
191191
// Free the underlying allocation once queued work has completed.
192192
#ifdef _GAMING_XBOX

onnxruntime/core/providers/dml/DmlExecutionProvider/src/BucketizedBufferAllocator.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,11 @@ namespace Dml
4646

4747
void SetDefaultRoundingMode(AllocatorRoundingMode roundingMode);
4848

49+
void Close()
50+
{
51+
m_closed = true;
52+
}
53+
4954
public: // onnxruntime::IAllocator
5055
void* Alloc(size_t size, AllocatorRoundingMode roundingMode);
5156
void* Alloc(size_t size) final;
@@ -83,6 +88,7 @@ namespace Dml
8388
std::vector<Bucket> m_pool;
8489
size_t m_currentAllocationId = 0;
8590
uint64_t m_currentResourceId = 0;
91+
bool m_closed = false;
8692

8793
// Unless specifically requested, allocation sizes are not rounded to enable pooling
8894
// until SetDefaultRoundingMode is called. This should be done at completion of session

onnxruntime/core/providers/dml/DmlExecutionProvider/src/CommandQueue.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ namespace Dml
5555
// for example, an allocation from BucketizedBufferAllocator attempts to queue a reference
5656
// to its underlying D3D resource when freed. Furthermore, these references are unnecessary
5757
// since Close() already blocks for scheduled GPU work before clearing m_queuedReferences.
58-
if (!m_closing)
58+
if (!m_clearingQueue)
5959
{
6060
QueuedReference queuedReference = {GetLastFenceValue(), object};
6161

@@ -70,15 +70,15 @@ namespace Dml
7070
}
7171
}
7272

73-
void CommandQueue::Close()
73+
void CommandQueue::WaitForSignalAndClearQueue()
7474
{
7575
// Wait for flushed work:
76-
assert(!m_closing);
77-
m_closing = true;
76+
assert(!m_clearingQueue);
77+
m_clearingQueue = true;
7878
GpuEvent event = GetCurrentCompletionEvent();
7979
event.WaitForSignal(m_cpuSyncSpinningEnabled);
8080
m_queuedReferences.clear();
81-
m_closing = false;
81+
m_clearingQueue = false;
8282
}
8383

8484
void CommandQueue::ReleaseCompletedReferences()

onnxruntime/core/providers/dml/DmlExecutionProvider/src/CommandQueue.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ namespace Dml
4444
}
4545
#endif
4646

47-
void Close();
47+
void WaitForSignalAndClearQueue();
4848
void ReleaseCompletedReferences();
4949

5050
private:
@@ -61,7 +61,7 @@ namespace Dml
6161

6262
ComPtr<ID3D12Fence> m_fence;
6363
uint64_t m_lastFenceValue = 0;
64-
bool m_closing = false;
64+
bool m_clearingQueue = false;
6565
bool m_cpuSyncSpinningEnabled = false;
6666
};
6767

onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionContext.cpp

Lines changed: 3 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,10 @@ namespace Dml
1111
ID3D12Device* d3d12Device,
1212
IDMLDevice* dmlDevice,
1313
ID3D12CommandQueue* queue,
14-
bool cpuSyncSpinningEnabled,
15-
bool keepOpen
16-
)
14+
bool cpuSyncSpinningEnabled)
1715
: m_queue(std::make_shared<CommandQueue>(queue, cpuSyncSpinningEnabled))
1816
, m_dmlRecorder(d3d12Device, dmlDevice, m_queue)
1917
, m_cpuSyncSpinningEnabled(cpuSyncSpinningEnabled)
20-
, m_keepOpen(keepOpen)
2118
{
2219
ORT_THROW_IF_FAILED(dmlDevice->GetParentDevice(IID_GRAPHICS_PPV_ARGS(m_d3dDevice.GetAddressOf())));
2320
}
@@ -36,8 +33,6 @@ namespace Dml
3633
D3D12_RESOURCE_STATES srcState,
3734
uint64_t byteCount)
3835
{
39-
assert(!m_closed);
40-
4136
SetCommandRecorder(&m_dmlRecorder);
4237

4338
std::vector<D3D12_RESOURCE_BARRIER> barriers;
@@ -84,8 +79,6 @@ namespace Dml
8479
_Out_ uint64_t* completionValue
8580
)
8681
{
87-
assert(!m_closed);
88-
8982
SetCommandRecorder(&m_dmlRecorder);
9083
m_dmlRecorder.ExecuteCommandList(commandList, fence, completionValue);
9184
}
@@ -95,7 +88,6 @@ namespace Dml
9588
const DML_BINDING_DESC& persistentResourceBinding,
9689
const DML_BINDING_DESC& inputArrayBinding)
9790
{
98-
assert(!m_closed);
9991
SetCommandRecorder(&m_dmlRecorder);
10092

10193
m_dmlRecorder.InitializeOperator(op, persistentResourceBinding, inputArrayBinding);
@@ -107,31 +99,27 @@ namespace Dml
10799
gsl::span<const DML_BINDING_DESC> inputBindings,
108100
gsl::span<const DML_BINDING_DESC> outputBindings)
109101
{
110-
assert(!m_closed);
111102
SetCommandRecorder(&m_dmlRecorder);
112103

113104
m_dmlRecorder.ExecuteOperator(op, persistentResourceBinding, inputBindings, outputBindings);
114105
}
115106

116107
void ExecutionContext::AddUAVBarrier()
117108
{
118-
assert(!m_closed);
119109
SetCommandRecorder(&m_dmlRecorder);
120110

121111
m_dmlRecorder.AddUAVBarrier();
122112
}
123113

124114
void ExecutionContext::ResourceBarrier(gsl::span<const D3D12_RESOURCE_BARRIER> barriers)
125115
{
126-
assert(!m_closed);
127116
SetCommandRecorder(&m_dmlRecorder);
128117

129118
m_dmlRecorder.ResourceBarrier(barriers);
130119
}
131120

132121
void ExecutionContext::GetCommandListForRecordingAndInvalidateState(ID3D12GraphicsCommandList** commandList)
133122
{
134-
assert(!m_closed);
135123
SetCommandRecorder(&m_dmlRecorder);
136124

137125
// Ensure the descriptor heap is reset to D3D as something external may change it before recording
@@ -142,8 +130,6 @@ namespace Dml
142130

143131
void ExecutionContext::SetCommandRecorder(ICommandRecorder* newRecorder)
144132
{
145-
assert(!m_closed);
146-
147133
// If changing which recorder is the current one, we need to flush the old one first. This is to ensure correct
148134
// ordering of operations on the command queue.
149135
if (m_currentRecorder != newRecorder)
@@ -160,8 +146,6 @@ namespace Dml
160146

161147
void ExecutionContext::Flush()
162148
{
163-
assert(!m_closed);
164-
165149
if (!m_currentRecorder || !m_currentRecorder->HasUnsubmittedWork())
166150
{
167151
// Nothing to flush
@@ -180,34 +164,21 @@ namespace Dml
180164

181165
void ExecutionContext::QueueReference(IUnknown* object)
182166
{
183-
assert(!m_closed);
184167
// If something has been recorded into a command list but not submitted yet, it means that the *next* fence
185168
// value is the one to signal completion.
186169
bool waitForUnsubmittedWork = (m_currentRecorder != nullptr);
187170
m_queue->QueueReference(object, waitForUnsubmittedWork);
188171
}
189172

190-
void ExecutionContext::Close()
173+
void ExecutionContext::WaitForSignalAndClearQueue()
191174
{
192-
assert(!m_closed);
193-
194175
// Discard unflushed work and clear queued references. This prevents the circular reference:
195176
// Kernel --> ProviderImpl --> Context --> QueuedRefs --> Kernel
196-
m_queue->Close();
197-
198-
// Keep the execution context open when requested, e.g. when used through the python API where there's a single context
199-
// and single command queue
200-
if (!m_keepOpen)
201-
{
202-
m_currentRecorder = nullptr;
203-
m_closed = true;
204-
}
177+
m_queue->WaitForSignalAndClearQueue();
205178
}
206179

207180
GpuEvent ExecutionContext::GetCurrentCompletionEvent()
208181
{
209-
assert(!m_closed);
210-
211182
GpuEvent event = m_queue->GetCurrentCompletionEvent();
212183

213184
// If something has been recorded into a command list but not submitted yet, it means that the *next* fence
@@ -223,13 +194,11 @@ namespace Dml
223194

224195
void ExecutionContext::ReleaseCompletedReferences()
225196
{
226-
assert(!m_closed);
227197
m_queue->ReleaseCompletedReferences();
228198
}
229199

230200
D3D12_COMMAND_LIST_TYPE ExecutionContext::GetCommandListTypeForQueue() const
231201
{
232-
assert(!m_closed);
233202
return m_queue->GetType();
234203
}
235204

onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionContext.h

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -23,14 +23,13 @@ namespace Dml
2323
ID3D12Device* d3d12Device,
2424
IDMLDevice* dmlDevice,
2525
ID3D12CommandQueue* queue,
26-
bool cpuSyncSpinningEnabled,
27-
bool keepOpen);
26+
bool cpuSyncSpinningEnabled);
2827

2928
void SetAllocator(std::weak_ptr<BucketizedBufferAllocator> allocator);
3029

3130
// Waits for flushed work, discards unflushed work, and discards associated references to
32-
// prevent circular references. Must be the last call on the object before destruction.
33-
void Close();
31+
// prevent circular references.
32+
void WaitForSignalAndClearQueue();
3433

3534
// Queues a CopyBufferRegion (see ID3D12GraphicsCommandList::CopyBufferRegion) for execution. Transition
3635
// barriers are automatically inserted to transition the source and destination resources to COPY_SOURCE and
@@ -87,7 +86,6 @@ namespace Dml
8786

8887
D3D12_COMMAND_LIST_TYPE GetCommandListTypeForQueue() const;
8988
bool CpuSyncSpinningEnabled() const { return m_cpuSyncSpinningEnabled; }
90-
bool IsClosed() const { return m_closed; }
9189

9290
private:
9391
Microsoft::WRL::ComPtr<ID3D12Device> m_d3dDevice;
@@ -103,10 +101,6 @@ namespace Dml
103101

104102
bool m_closed = false;
105103
bool m_cpuSyncSpinningEnabled = false;
106-
107-
// The python API has a global state used for I/O binding where the execution context is shared between session,
108-
// so we don't want to close the context when one of the sessions is destroyed
109-
bool m_keepOpen = false;
110104
};
111105

112106
} // namespace Dml

onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,26 @@ namespace Dml
106106
// Release the cached command list references before closing the context
107107
m_capturedGraphs.clear();
108108

109-
m_context->Close();
109+
// Close the allocator before clearing the command queue to stop it from
110+
// appending resources to it in an attempt to keep them alive.
111+
if (m_allocator)
112+
{
113+
m_allocator->Close();
114+
}
115+
116+
// Destroy the allocators. We are closing the execution provider, so from now on the
117+
// only thing it will be used for is doing copies via the DataTransfer, which doesn't
118+
// require allocating any memory.
119+
// TODO: Move the copy functions over to ExecutionContext so that we are able to cleanly
120+
// destroy ExecutionProviderImpl, and instead have the DataTransfer keep the context alive.
121+
m_allocator = nullptr;
122+
m_cpuInputAllocator = nullptr;
123+
124+
// Wait for all pending commands to be done executing and empty the command queue. This will
125+
// Force all kernels and resources in flight to get destroyed and, from this point forward,
126+
// ExecutionProviderImpl will only be used to execute transfer between resources that are
127+
// already existing via the DataTransfer;
128+
m_context->WaitForSignalAndClearQueue();
110129
}
111130

112131
void ExecutionProviderImpl::WaitForOutstandingWork()

onnxruntime/core/providers/dml/dml_provider_factory.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -86,11 +86,11 @@ std::unique_ptr<IExecutionProvider> DMLProviderFactory::CreateProvider() {
8686

8787
// First, check if an I/O binding API that was used before this session or another session has already created a queue
8888
if (FAILED(d3d12_device->GetPrivateData(dml_execution_context_guid, &execution_context_ptr_size, execution_context.GetAddressOf()))) {
89-
execution_context = wil::MakeOrThrow<Dml::ExecutionContext>(d3d12_device.Get(), dml_device_.Get(), cmd_queue_.Get(), true, true);
89+
execution_context = wil::MakeOrThrow<Dml::ExecutionContext>(d3d12_device.Get(), dml_device_.Get(), cmd_queue_.Get(), true);
9090
ORT_THROW_IF_FAILED(d3d12_device->SetPrivateDataInterface(dml_execution_context_guid, execution_context.Get()));
9191
}
9292
} else {
93-
execution_context = wil::MakeOrThrow<Dml::ExecutionContext>(d3d12_device.Get(), dml_device_.Get(), cmd_queue_.Get(), cpu_sync_spinning_enabled_, false);
93+
execution_context = wil::MakeOrThrow<Dml::ExecutionContext>(d3d12_device.Get(), dml_device_.Get(), cmd_queue_.Get(), cpu_sync_spinning_enabled_);
9494
}
9595

9696
auto provider = Dml::CreateExecutionProvider(dml_device_.Get(), execution_context.Get(), metacommands_enabled_, graph_capture_enabled_, cpu_sync_spinning_enabled_, disable_memory_arena_);

0 commit comments

Comments
 (0)