Skip to content

Commit 7460c75

Browse files
jvandebonamathewc
authored andcommitted
[MTIA] Ensure correct stream behavior for input_buffer add autograd on MTIA (pytorch#149433)
Test Plan: CI Differential Revision: D71414498 Pull Request resolved: pytorch#149433 Approved by: https://github.com/albanD
1 parent e9980dc commit 7460c75

File tree

1 file changed

+7
-7
lines changed

1 file changed

+7
-7
lines changed

torch/csrc/autograd/input_buffer.cpp

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -144,25 +144,25 @@ void InputBuffer::add(
144144
// (2a) Uses the consumer's stream as the accumulation stream
145145
// (2b) Syncs the accumulation stream with the producer's stream (if
146146
// different) (2c) Accumulates.
147-
// (3) var is a CUDA/privateuse1 variable and it shares a device with the
148-
// consumer but not the producer:
147+
// (3) var is a CUDA/MTIA/privateuse1 variable and it shares a device with
148+
// the consumer but not the producer:
149149
// (3a) Uses the consumer's stream as the accumulation stream
150150
// (3b) Syncs the accumulation stream with the consumer device's default
151151
// stream (3c) Accumulates.
152-
// (4) var is a CUDA/privateuse1 variable and it shares a device with the
153-
// producer but not the consumer:
152+
// (4) var is a CUDA/MTIA/privateuse1 variable and it shares a device with
153+
// the producer but not the consumer:
154154
// (4a) Uses the producer device's default stream as the accumulation
155155
// stream (4b) Syncs the accumulation stream with the producer's
156156
// stream (4c) Accumulates.
157-
// (5) var is a CUDA/privateuse1 variable and it does not share a device with
158-
// the consumer or producer.
157+
// (5) var is a CUDA/MTIA/privateuse1 variable and it does not share a device
158+
// with the consumer or producer.
159159
// Accumulation happens on the var device's default stream.
160160

161161
auto const device = device_of(var);
162162
TORCH_INTERNAL_ASSERT(device.has_value());
163163
std::optional<c10::Stream> opt_accumulate_stream = std::nullopt;
164164
const auto device_type = device->type();
165-
if (device->is_cuda() || device->is_privateuseone()) {
165+
if (device->is_cuda() || device->is_mtia() || device->is_privateuseone()) {
166166
const auto on_producer =
167167
opt_producer_stream && device == opt_producer_stream->device();
168168
const auto on_consumer =

0 commit comments

Comments
 (0)