@@ -144,25 +144,25 @@ void InputBuffer::add(
144144 // (2a) Uses the consumer's stream as the accumulation stream
145145 // (2b) Syncs the accumulation stream with the producer's stream (if
146146 // different) (2c) Accumulates.
147- // (3) var is a CUDA/privateuse1 variable and it shares a device with the
148- // consumer but not the producer:
147+ // (3) var is a CUDA/MTIA/privateuse1 variable and it shares a device with
148+ // the consumer but not the producer:
149149 // (3a) Uses the consumer's stream as the accumulation stream
150150 // (3b) Syncs the accumulation stream with the consumer device's default
151151 // stream (3c) Accumulates.
152- // (4) var is a CUDA/privateuse1 variable and it shares a device with the
153- // producer but not the consumer:
152+ // (4) var is a CUDA/MTIA/privateuse1 variable and it shares a device with
153+ // the producer but not the consumer:
154154 // (4a) Uses the producer device's default stream as the accumulation
155155 // stream (4b) Syncs the accumulation stream with the producer's
156156 // stream (4c) Accumulates.
157- // (5) var is a CUDA/privateuse1 variable and it does not share a device with
158- // the consumer or producer.
157+ // (5) var is a CUDA/MTIA/privateuse1 variable and it does not share a device
158+ // with the consumer or producer.
159159 // Accumulation happens on the var device's default stream.
160160
161161 auto const device = device_of (var);
162162 TORCH_INTERNAL_ASSERT (device.has_value ());
163163 std::optional<c10::Stream> opt_accumulate_stream = std::nullopt ;
164164 const auto device_type = device->type ();
165- if (device->is_cuda () || device->is_privateuseone ()) {
165+ if (device->is_cuda () || device->is_mtia () || device-> is_privateuseone ()) {
166166 const auto on_producer =
167167 opt_producer_stream && device == opt_producer_stream->device ();
168168 const auto on_consumer =
0 commit comments