@@ -895,3 +895,49 @@ util.func public @multiAffinityTrip(%input_timepoint: !stream.timepoint, %input_
895895 // CHECK: util.return %[[TIMEPOINT_B]]
896896 util.return %timepoint_b , %output_handle : !stream.timepoint , i64
897897}
898+
899+ // -----
900+
901+ // A multi-affinity test with a larger number of devices. This IR used to crash.
902+
903+ // CHECK-LABEL: multiAffinityLotsOfDevices
904+ util.func public @multiAffinityLotsOfDevices (%wait_timepoint: !stream.timepoint ) {
905+ // Constant from 'device_a' is used on a multiple devices.
906+ // CHECK: stream.resource.constants on(#hal.device.optimal<[
907+ // CHECK-SAME: #hal.device.promise<@device_a>, #hal.device.promise<@device_b>,
908+ // CHECK-SAME: #hal.device.promise<@device_c>, #hal.device.promise<@device_d>,
909+ // CHECK-SAME: #hal.device.promise<@device_e>, #hal.device.promise<@device_f>,
910+ // CHECK-SAME: #hal.device.promise<@device_g>, #hal.device.promise<@device_h>
911+ // CHECK-SAME: ]>) :
912+ // CHECK-NEXT: !stream.resource<constant>{%c16} = dense<4>
913+ %c16 = arith.constant 16 : index
914+ %result_a , %result_timepoint_a = stream.async.execute on (#hal.device.promise <@device_a >) await (%wait_timepoint ) => with () -> (!stream.resource <constant >{%c16 }) {
915+ %cst16 = stream.async.constant : !stream.resource <constant >{%c16 } = dense <4 > : tensor <4 x2 xi16 >
916+ stream.yield %cst16 : !stream.resource <constant >{%c16 }
917+ } => !stream.timepoint
918+
919+ // Result from 'device_a' used on a large number of other devices.
920+ %result_b , %result_timepoint_b = stream.async.execute on (#hal.device.promise <@device_b >) await (%result_timepoint_a ) => with (%result_a as %capture: !stream.resource <constant >{%c16 }) -> (!stream.resource <constant >{%c16 }) {
921+ stream.yield %capture : !stream.resource <constant >{%c16 }
922+ } => !stream.timepoint
923+ %result_c , %result_timepoint_c = stream.async.execute on (#hal.device.promise <@device_c >) await (%result_timepoint_a ) => with (%result_a as %capture: !stream.resource <constant >{%c16 }) -> (!stream.resource <constant >{%c16 }) {
924+ stream.yield %capture : !stream.resource <constant >{%c16 }
925+ } => !stream.timepoint
926+ %result_d , %result_timepoint_d = stream.async.execute on (#hal.device.promise <@device_d >) await (%result_timepoint_a ) => with (%result_a as %capture: !stream.resource <constant >{%c16 }) -> (!stream.resource <constant >{%c16 }) {
927+ stream.yield %capture : !stream.resource <constant >{%c16 }
928+ } => !stream.timepoint
929+ %result_e , %result_timepoint_e = stream.async.execute on (#hal.device.promise <@device_e >) await (%result_timepoint_a ) => with (%result_a as %capture: !stream.resource <constant >{%c16 }) -> (!stream.resource <constant >{%c16 }) {
930+ stream.yield %capture : !stream.resource <constant >{%c16 }
931+ } => !stream.timepoint
932+ %result_f , %result_timepoint_f = stream.async.execute on (#hal.device.promise <@device_f >) await (%result_timepoint_a ) => with (%result_a as %capture: !stream.resource <constant >{%c16 }) -> (!stream.resource <constant >{%c16 }) {
933+ stream.yield %capture : !stream.resource <constant >{%c16 }
934+ } => !stream.timepoint
935+ %result_g , %result_timepoint_g = stream.async.execute on (#hal.device.promise <@device_g >) await (%result_timepoint_a ) => with (%result_a as %capture: !stream.resource <constant >{%c16 }) -> (!stream.resource <constant >{%c16 }) {
936+ stream.yield %capture : !stream.resource <constant >{%c16 }
937+ } => !stream.timepoint
938+ %result_h , %result_timepoint_h = stream.async.execute on (#hal.device.promise <@device_h >) await (%result_timepoint_a ) => with (%result_a as %capture: !stream.resource <constant >{%c16 }) -> (!stream.resource <constant >{%c16 }) {
939+ stream.yield %capture : !stream.resource <constant >{%c16 }
940+ } => !stream.timepoint
941+
942+ util.return
943+ }
0 commit comments