Skip to content

Commit 03da441

Browse files
pavanbalajimeta-codesync[bot]
authored and committed
Remove completed_work_queue_
Summary: The completed_work_queue_ is no longer needed. It is OK to free work objects and the corresponding tensors from the timeout thread as well.

Reviewed By: tanquer

Differential Revision: D85455174

fbshipit-source-id: e74fab99702515a1295533c79fec83f0e2c50adf
1 parent 1c24ac7 commit 03da441

File tree

6 files changed

+25
-44
lines changed

6 files changed

+25
-44
lines changed

comms/torchcomms/ncclx/TorchCommNCCLX.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -297,7 +297,7 @@ class TorchCommNCCLX : public TorchCommBackend,
297297
void timeoutWatchdog() noexcept;
298298
void checkInitialized() const;
299299
void checkAndAbortIfTimedOutOrError();
300-
void checkWorkQueue(bool isMainThread);
300+
void checkWorkQueue();
301301
void enqueueWork(std::shared_ptr<TorchWorkNCCLX> work, cudaStream_t stream);
302302
bool getGraphCaptureMode();
303303
cudaStream_t getOperationStream(bool async_op);

comms/torchcomms/ncclx/TorchCommNCCLXUtils.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -157,8 +157,8 @@ NcclxWindowCmpOp TorchCommNCCLX::getNcclSignalCmpOp(SignalCmpOp op) {
157157
#endif
158158
}
159159

160-
void TorchCommNCCLX::checkWorkQueue(bool isMainThread) {
161-
TorchWorkNCCLX::WorkStatus status = workq_.garbageCollect(isMainThread);
160+
void TorchCommNCCLX::checkWorkQueue() {
161+
TorchWorkNCCLX::WorkStatus status = workq_.garbageCollect();
162162

163163
switch (status) {
164164
case TorchWorkNCCLX::WorkStatus::TIMEDOUT:
@@ -192,7 +192,7 @@ void TorchCommNCCLX::timeoutWatchdog() noexcept {
192192
}
193193

194194
// Check work objects for completion or timeout
195-
checkWorkQueue(false);
195+
checkWorkQueue();
196196
if (comm_state_ != CommState::NORMAL &&
197197
options_.abort_process_on_timeout_or_error) {
198198
// Log the error and abort the process. We cannot abort the NCCL
@@ -226,7 +226,7 @@ void TorchCommNCCLX::checkAndAbortIfTimedOutOrError() {
226226
}
227227

228228
// First, check work queue status
229-
checkWorkQueue(true);
229+
checkWorkQueue();
230230

231231
if (comm_state_ == CommState::TIMEOUT) {
232232
abortNcclComm();

comms/torchcomms/ncclx/TorchWorkNCCLX.cpp

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -121,9 +121,6 @@ TorchWorkNCCLX::WorkStatus TorchWorkNCCLX::checkStatus() {
121121
if (end_status == cudaSuccess) {
122122
// End event has completed, mark the work as completed
123123
state_ = WorkStatus::COMPLETED;
124-
125-
// Release the input tensors
126-
inputTensors_.clear();
127124
} else if (end_status == cudaErrorNotReady) {
128125
// End event has not completed yet, check for timeout
129126
auto current_time = std::chrono::steady_clock::now();

comms/torchcomms/ncclx/TorchWorkNCCLX.hpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -95,15 +95,14 @@ class TorchWorkNCCLXQueue {
9595
TorchWorkNCCLXQueue() = default;
9696
~TorchWorkNCCLXQueue() = default;
9797

98-
TorchWorkNCCLX::WorkStatus garbageCollect(bool isMainThread);
98+
TorchWorkNCCLX::WorkStatus garbageCollect();
9999
// Finalize function can only be called from the main thread
100100
TorchWorkNCCLX::WorkStatus finalize();
101101
void enqueueWork(std::shared_ptr<TorchWorkNCCLX> work, cudaStream_t stream);
102102

103103
private:
104104
std::unordered_map<cudaStream_t, std::queue<std::shared_ptr<TorchWorkNCCLX>>>
105105
stream_work_queues_;
106-
std::vector<std::shared_ptr<TorchWorkNCCLX>> completed_work_queue_;
107106
std::recursive_mutex work_queues_mutex_;
108107

109108
friend class TorchWorkNCCLXQueueCommTest;

comms/torchcomms/ncclx/TorchWorkNCCLXQueue.cpp

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,7 @@
55
namespace torch {
66
namespace comms {
77

8-
TorchWorkNCCLX::WorkStatus TorchWorkNCCLXQueue::garbageCollect(
9-
bool isMainThread) {
8+
TorchWorkNCCLX::WorkStatus TorchWorkNCCLXQueue::garbageCollect() {
109
std::lock_guard<std::recursive_mutex> lock(work_queues_mutex_);
1110

1211
TorchWorkNCCLX::WorkStatus last_status =
@@ -30,7 +29,6 @@ TorchWorkNCCLX::WorkStatus TorchWorkNCCLXQueue::garbageCollect(
3029
if (status == TorchWorkNCCLX::WorkStatus::COMPLETED) {
3130
// Work is completed, remove it from the work queue
3231
work_queue.pop();
33-
completed_work_queue_.push_back(work);
3432
// Continue to the next element in the queue
3533
} else if (
3634
status == TorchWorkNCCLX::WorkStatus::TIMEDOUT ||
@@ -51,11 +49,6 @@ TorchWorkNCCLX::WorkStatus TorchWorkNCCLXQueue::garbageCollect(
5149
}
5250
}
5351

54-
if (isMainThread) {
55-
// If we are the main thread, clear the completed work queues
56-
completed_work_queue_.clear();
57-
}
58-
5952
return last_status;
6053
}
6154

@@ -71,7 +64,7 @@ TorchWorkNCCLX::WorkStatus TorchWorkNCCLXQueue::finalize() {
7164
// empty
7265
TorchWorkNCCLX::WorkStatus status = TorchWorkNCCLX::WorkStatus::COMPLETED;
7366
while (!stream_work_queues_.empty()) {
74-
status = garbageCollect(true);
67+
status = garbageCollect();
7568
if (status == TorchWorkNCCLX::WorkStatus::ERROR ||
7669
status == TorchWorkNCCLX::WorkStatus::TIMEDOUT ||
7770
status == TorchWorkNCCLX::WorkStatus::COMPLETED) {
@@ -84,7 +77,6 @@ TorchWorkNCCLX::WorkStatus TorchWorkNCCLXQueue::finalize() {
8477
// NOTE: finalize MUST return without holding references to any work object,
8578
// otherwise it may leak object and cause side effects.
8679
stream_work_queues_.clear();
87-
completed_work_queue_.clear();
8880

8981
return status;
9082
}

comms/torchcomms/ncclx/tests/unit/cpp/TorchWorkNCCLXQueueTest.cpp

Lines changed: 17 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -146,18 +146,14 @@ class TorchWorkNCCLXQueueCommTest : public ::testing::Test {
146146
EXPECT_CALL(*mock_hook_, clear()).Times(times_clear);
147147
}
148148

149-
void checkWorkQueue(bool isMainThread) {
150-
comm_->checkWorkQueue(isMainThread);
149+
void checkWorkQueue() {
150+
comm_->checkWorkQueue();
151151
}
152152

153153
const auto& getStreamWorkQueues() {
154154
return comm_->workq_.stream_work_queues_;
155155
}
156156

157-
const auto& getCompletedWorkQueue() {
158-
return comm_->workq_.completed_work_queue_;
159-
}
160-
161157
cudaEvent_t getAsyncDependencyEvent() {
162158
return comm_->dependency_event_;
163159
}
@@ -180,7 +176,7 @@ class TorchWorkNCCLXQueueCommTest : public ::testing::Test {
180176

181177
TEST_F(TorchWorkNCCLXQueueTest, GarbageCollectEmptyQueue) {
182178
// Test garbage collection on empty queue
183-
auto status = queue_->garbageCollect(false);
179+
auto status = queue_->garbageCollect();
184180
EXPECT_EQ(status, TorchWorkNCCLX::WorkStatus::COMPLETED);
185181
}
186182

@@ -191,9 +187,9 @@ TEST_F(TorchWorkNCCLXQueueTest, FinalizeEmptyQueue) {
191187

192188
TEST_F(TorchWorkNCCLXQueueTest, MultipleGarbageCollectCalls) {
193189
// Multiple garbage collect calls on empty queue should be safe
194-
auto status1 = queue_->garbageCollect(false);
195-
auto status2 = queue_->garbageCollect(false);
196-
auto status3 = queue_->garbageCollect(true);
190+
auto status1 = queue_->garbageCollect();
191+
auto status2 = queue_->garbageCollect();
192+
auto status3 = queue_->garbageCollect();
197193

198194
EXPECT_EQ(status1, TorchWorkNCCLX::WorkStatus::COMPLETED);
199195
EXPECT_EQ(status2, TorchWorkNCCLX::WorkStatus::COMPLETED);
@@ -202,7 +198,7 @@ TEST_F(TorchWorkNCCLXQueueTest, MultipleGarbageCollectCalls) {
202198

203199
TEST_F(TorchWorkNCCLXQueueTest, MultipleFinalizeCallsAfterGarbageCollect) {
204200
// Garbage collect first
205-
auto gc_status = queue_->garbageCollect(false);
201+
auto gc_status = queue_->garbageCollect();
206202
EXPECT_EQ(gc_status, TorchWorkNCCLX::WorkStatus::COMPLETED);
207203

208204
// Multiple finalize calls should be safe
@@ -215,8 +211,8 @@ TEST_F(TorchWorkNCCLXQueueTest, MultipleFinalizeCallsAfterGarbageCollect) {
215211

216212
TEST_F(TorchWorkNCCLXQueueTest, GarbageCollectMainThreadFlag) {
217213
// Test that the isMainThread flag doesn't cause issues on empty queue
218-
auto status1 = queue_->garbageCollect(false);
219-
auto status2 = queue_->garbageCollect(true);
214+
auto status1 = queue_->garbageCollect();
215+
auto status2 = queue_->garbageCollect();
220216

221217
EXPECT_EQ(status1, TorchWorkNCCLX::WorkStatus::COMPLETED);
222218
EXPECT_EQ(status2, TorchWorkNCCLX::WorkStatus::COMPLETED);
@@ -232,17 +228,16 @@ TEST_F(TorchWorkNCCLXQueueTest, ConcurrentGarbageCollectCalls) {
232228
// mutex-protected operations work correctly with multiple calls
233229

234230
for (int i = 0; i < 10; ++i) {
235-
auto status =
236-
queue_->garbageCollect(i % 2 == 0); // Alternate main thread flag
231+
auto status = queue_->garbageCollect();
237232
EXPECT_EQ(status, TorchWorkNCCLX::WorkStatus::COMPLETED);
238233
}
239234
}
240235

241236
TEST_F(TorchWorkNCCLXQueueTest, ConcurrentFinalizeAndGarbageCollect) {
242237
// Test that finalize and garbage collect can be called in sequence safely
243-
auto gc_status = queue_->garbageCollect(false);
238+
auto gc_status = queue_->garbageCollect();
244239
auto finalize_status = queue_->finalize();
245-
auto gc_status2 = queue_->garbageCollect(true);
240+
auto gc_status2 = queue_->garbageCollect();
246241

247242
EXPECT_EQ(gc_status, TorchWorkNCCLX::WorkStatus::COMPLETED);
248243
EXPECT_EQ(finalize_status, TorchWorkNCCLX::WorkStatus::COMPLETED);
@@ -278,7 +273,7 @@ TEST_F(TorchWorkNCCLXQueueTest, QueueCreationAndDestruction) {
278273
EXPECT_NE(queue2, nullptr);
279274

280275
// Test basic operations on new queue
281-
auto status = queue2->garbageCollect(false);
276+
auto status = queue2->garbageCollect();
282277
EXPECT_EQ(status, TorchWorkNCCLX::WorkStatus::COMPLETED);
283278

284279
status = queue2->finalize();
@@ -294,8 +289,8 @@ TEST_F(TorchWorkNCCLXQueueTest, MultipleQueuesIndependent) {
294289
auto queue3 = std::make_unique<TorchWorkNCCLXQueue>();
295290

296291
// Operations on different queues should not interfere
297-
auto status1 = queue_->garbageCollect(false);
298-
auto status2 = queue2->garbageCollect(true);
292+
auto status1 = queue_->garbageCollect();
293+
auto status2 = queue2->garbageCollect();
299294
auto status3 = queue3->finalize();
300295

301296
EXPECT_EQ(status1, TorchWorkNCCLX::WorkStatus::COMPLETED);
@@ -322,12 +317,11 @@ TEST_F(TorchWorkNCCLXQueueCommTest, NoLeakedObjectsAfterFinalize) {
322317
auto work = comm_->send(tensor, 1, true); // async send
323318

324319
// Simulate the timeout thread calling checkWorkQueue
325-
checkWorkQueue(/*isMainThread=*/false);
320+
checkWorkQueue();
326321
// Comm finalize will call the work queue finalize().
327322
comm_->finalize();
328323

329324
EXPECT_EQ(getStreamWorkQueues().size(), 0);
330-
EXPECT_EQ(getCompletedWorkQueue().size(), 0);
331325
}
332326

333327
TEST_F(TorchWorkNCCLXQueueCommTest, NoFailureUnderCudaGraphMode) {
@@ -370,12 +364,11 @@ TEST_F(TorchWorkNCCLXQueueCommTest, NoFailureUnderCudaGraphMode) {
370364
auto work = comm_->send(tensor, 1, true); // async send
371365

372366
// Simulate the timeout thread calling checkWorkQueue
373-
checkWorkQueue(/*isMainThread=*/false);
367+
checkWorkQueue();
374368
// Comm finalize will call the work queue finalize().
375369
comm_->finalize();
376370

377371
EXPECT_EQ(getStreamWorkQueues().size(), 0);
378-
EXPECT_EQ(getCompletedWorkQueue().size(), 0);
379372
}
380373

381374
} // namespace comms

0 commit comments

Comments (0)