Skip to content

Commit fdcd945

Browse files
authored
Fix missing return in DeviceOptimalAttr::joinOR (iree-org#21228)
This code was using a `MapVector` iterator after an insertion, and the insertion invalidates that iterator. This only _happens_ to work when the insertion doesn't trigger a reallocation. Fixes iree-org#21203.
1 parent 324be39 commit fdcd945

File tree

2 files changed

+47
-0
lines changed

2 files changed

+47
-0
lines changed

compiler/src/iree/compiler/Dialect/HAL/IR/HALAttrs.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1231,6 +1231,7 @@ DeviceOptimalAttr::joinOR(IREE::Stream::AffinityAttr other) const {
12311231
if (it == affinitySet.end()) {
12321232
// New device entry.
12331233
affinitySet.insert({otherDeviceAttr, affinityAttr});
1234+
return true;
12341235
}
12351236
// OR in with existing entry.
12361237
auto joinedAttr = it->second.joinOR(other);

compiler/src/iree/compiler/Dialect/Stream/Transforms/test/schedule_allocation.mlir

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -895,3 +895,49 @@ util.func public @multiAffinityTrip(%input_timepoint: !stream.timepoint, %input_
895895
// CHECK: util.return %[[TIMEPOINT_B]]
896896
util.return %timepoint_b, %output_handle : !stream.timepoint, i64
897897
}
898+
899+
// -----
900+
901+
// A multi-affinity test with a larger number of devices. This IR used to crash.
902+
903+
// CHECK-LABEL: multiAffinityLotsOfDevices
904+
util.func public @multiAffinityLotsOfDevices(%wait_timepoint: !stream.timepoint) {
905+
// Constant from 'device_a' is used on multiple devices.
906+
// CHECK: stream.resource.constants on(#hal.device.optimal<[
907+
// CHECK-SAME: #hal.device.promise<@device_a>, #hal.device.promise<@device_b>,
908+
// CHECK-SAME: #hal.device.promise<@device_c>, #hal.device.promise<@device_d>,
909+
// CHECK-SAME: #hal.device.promise<@device_e>, #hal.device.promise<@device_f>,
910+
// CHECK-SAME: #hal.device.promise<@device_g>, #hal.device.promise<@device_h>
911+
// CHECK-SAME: ]>) :
912+
// CHECK-NEXT: !stream.resource<constant>{%c16} = dense<4>
913+
%c16 = arith.constant 16 : index
914+
%result_a, %result_timepoint_a = stream.async.execute on(#hal.device.promise<@device_a>) await(%wait_timepoint) => with() -> (!stream.resource<constant>{%c16}) {
915+
%cst16 = stream.async.constant : !stream.resource<constant>{%c16} = dense<4> : tensor<4x2xi16>
916+
stream.yield %cst16 : !stream.resource<constant>{%c16}
917+
} => !stream.timepoint
918+
919+
// Result from 'device_a' used on a large number of other devices.
920+
%result_b, %result_timepoint_b = stream.async.execute on(#hal.device.promise<@device_b>) await(%result_timepoint_a) => with(%result_a as %capture: !stream.resource<constant>{%c16}) -> (!stream.resource<constant>{%c16}) {
921+
stream.yield %capture : !stream.resource<constant>{%c16}
922+
} => !stream.timepoint
923+
%result_c, %result_timepoint_c = stream.async.execute on(#hal.device.promise<@device_c>) await(%result_timepoint_a) => with(%result_a as %capture: !stream.resource<constant>{%c16}) -> (!stream.resource<constant>{%c16}) {
924+
stream.yield %capture : !stream.resource<constant>{%c16}
925+
} => !stream.timepoint
926+
%result_d, %result_timepoint_d = stream.async.execute on(#hal.device.promise<@device_d>) await(%result_timepoint_a) => with(%result_a as %capture: !stream.resource<constant>{%c16}) -> (!stream.resource<constant>{%c16}) {
927+
stream.yield %capture : !stream.resource<constant>{%c16}
928+
} => !stream.timepoint
929+
%result_e, %result_timepoint_e = stream.async.execute on(#hal.device.promise<@device_e>) await(%result_timepoint_a) => with(%result_a as %capture: !stream.resource<constant>{%c16}) -> (!stream.resource<constant>{%c16}) {
930+
stream.yield %capture : !stream.resource<constant>{%c16}
931+
} => !stream.timepoint
932+
%result_f, %result_timepoint_f = stream.async.execute on(#hal.device.promise<@device_f>) await(%result_timepoint_a) => with(%result_a as %capture: !stream.resource<constant>{%c16}) -> (!stream.resource<constant>{%c16}) {
933+
stream.yield %capture : !stream.resource<constant>{%c16}
934+
} => !stream.timepoint
935+
%result_g, %result_timepoint_g = stream.async.execute on(#hal.device.promise<@device_g>) await(%result_timepoint_a) => with(%result_a as %capture: !stream.resource<constant>{%c16}) -> (!stream.resource<constant>{%c16}) {
936+
stream.yield %capture : !stream.resource<constant>{%c16}
937+
} => !stream.timepoint
938+
%result_h, %result_timepoint_h = stream.async.execute on(#hal.device.promise<@device_h>) await(%result_timepoint_a) => with(%result_a as %capture: !stream.resource<constant>{%c16}) -> (!stream.resource<constant>{%c16}) {
939+
stream.yield %capture : !stream.resource<constant>{%c16}
940+
} => !stream.timepoint
941+
942+
util.return
943+
}

0 commit comments

Comments
 (0)