1616#crazy_2d_src = #ttg.linear <{register = [[0 , 2 ], [2 , 0 ]], lane = [[0 , 8 ], [8 , 0 ], [1 , 0 ], [4 , 0 ], [16 , 0 ]], warp = [[0 , 1 ], [0 , 4 ]], block = []}>
1717#crazy_2d_idx = #ttg.linear <{register = [[2 , 0 ], [0 , 2 ]], lane = [[0 , 8 ], [16 , 0 ], [1 , 0 ], [8 , 0 ], [4 , 0 ]], warp = [[0 , 1 ], [0 , 4 ]], block = []}>
1818
19- module attributes {" ttg.num-ctas" = 1 : i32 , " ttg.num-warps" = 1 : i32 , " ttg.threads-per-warp" = 32 : i32 } {
19+ #broadcasted_lane_1d = #ttg.blocked <{sizePerThread = [1 ], threadsPerWarp = [32 ], warpsPerCTA = [4 ], order = [0 ]}>
20+ #broadcasted_warp_2d = #ttg.blocked <{sizePerThread = [2 , 1 ], threadsPerWarp = [32 , 1 ], warpsPerCTA = [1 , 4 ], order = [1 , 0 ]}>
21+
22+ module attributes {" ttg.num-ctas" = 1 : i32 , " ttg.num-warps" = 4 : i32 , " ttg.threads-per-warp" = 32 : i32 } {
2023
2124// Each source element is mapped to a single thread, so we expect one index shuffle.
2225// CHECK-LABEL: @gather_warp_local_trivial
@@ -222,6 +225,59 @@ tt.func private @gather_2d_crazy(%arg0: tensor<32x16xi32, #crazy_2d_idx>, %arg1:
222225 tt.return %0 : tensor <32 x16 xf32 , #crazy_2d_idx >
223226}
224227
228+ // The tensor has only 16 elements, so within each warp both half-warps are mapped
229+ // onto the same 16 elements. It therefore doesn't matter if a lane in the second
230+ // half [16, 32) indexes into [0, 16): both halves hold the same data.
231+ // CHECK-LABEL: @gather_broadcasted_lane_1d
232+ tt.func private @gather_broadcasted_lane_1d (%arg0: tensor <16 xi32 , #broadcasted_lane_1d >, %arg1: tensor <16 xf32 , #broadcasted_lane_1d >) -> tensor <16 xf32 , #broadcasted_lane_1d > {
233+ // CHECK-NEXT: [[SRC:%.*]] = extractvalue { float } %1, 0
234+ // CHECK-NEXT: [[IDX:%.*]] = extractvalue { i32 } %0, 0
235+
236+ // CHECK-NEXT: [[LANEID:%.*]] = and i32 [[IDX]], 15
237+ // CHECK-NEXT: [[VALUE:%.*]] = bitcast float [[SRC]] to i32
238+ // CHECK-NEXT: [[RES_i32:%.*]] = tail call i32 @llvm.nvvm.shfl.sync.idx.i32(i32 -1, i32 [[VALUE]], i32 [[LANEID]], i32 31)
239+ %0 = tt.gather %arg1 [%arg0 ] {axis = 0 : i32 } : (tensor <16 xf32 , #broadcasted_lane_1d >, tensor <16 xi32 , #broadcasted_lane_1d >) -> tensor <16 xf32 , #broadcasted_lane_1d >
240+
241+ // CHECK-NEXT: [[RES:%.*]] = bitcast i32 [[RES_i32]] to float
242+ // CHECK-NEXT: ret float [[RES]]
243+ tt.return %0 : tensor <16 xf32 , #broadcasted_lane_1d >
244+ }
245+
246+ // Single gather column with 64 elements, all of which have to fit into a single
247+ // warp, so the whole column is broadcast across the 4 warps. Each warp processes
248+ // the same data, so the warp ID doesn't matter.
249+ // CHECK-LABEL: @gather_broadcasted_warp_2d
250+ tt.func private @gather_broadcasted_warp_2d (%arg0: tensor <64 x1 xi32 , #broadcasted_warp_2d >, %arg1: tensor <64 x1 xf32 , #broadcasted_warp_2d >) -> tensor <64 x1 xf32 , #broadcasted_warp_2d > {
251+ // CHECK-NEXT: [[SRC0:%.*]] = extractvalue { float, float } %1, 0
252+ // CHECK-NEXT: [[SRC1:%.*]] = extractvalue { float, float } %1, 1
253+ // CHECK-NEXT: [[IDX0:%.*]] = extractvalue { i32, i32 } %0, 0
254+ // CHECK-NEXT: [[IDX1:%.*]] = extractvalue { i32, i32 } %0, 1
255+
256+ // CHECK-NEXT: [[REGID0:%.*]] = and i32 [[IDX0]], 1
257+ // CHECK-NEXT: [[TMP:%.*]] = lshr i32 [[IDX0]], 1
258+ // CHECK-NEXT: [[LANEID0:%.*]] = and i32 [[TMP]], 31
259+
260+ // CHECK-NEXT: [[VALUE0:%.*]] = bitcast float [[SRC0]] to i32
261+ // CHECK-NEXT: [[RES0_i32:%.*]] = tail call i32 @llvm.nvvm.shfl.sync.idx.i32(i32 -1, i32 [[VALUE0]], i32 [[LANEID0]], i32 31)
262+ // CHECK-NEXT: [[VALUE1:%.*]] = bitcast float [[SRC1]] to i32
263+ // CHECK-NEXT: [[RES1_i32:%.*]] = tail call i32 @llvm.nvvm.shfl.sync.idx.i32(i32 -1, i32 [[VALUE1]], i32 [[LANEID0]], i32 31)
264+
265+ // CHECK-NEXT: [[PICK0:%.*]] = icmp eq i32 [[REGID0]], 0
266+ // CHECK-NEXT: select i1 [[PICK0]], i32 [[RES0_i32]], i32 [[RES1_i32]]
267+
268+ // CHECK: [[REGID1:%.*]] = and i32 [[IDX1]], 1
269+ // CHECK-NEXT: [[TMP:%.*]] = lshr i32 [[IDX1]], 1
270+ // CHECK-NEXT: [[LANEID1:%.*]] = and i32 [[TMP]], 31
271+
272+ // CHECK-NEXT: [[RES0_i32:%.*]] = tail call i32 @llvm.nvvm.shfl.sync.idx.i32(i32 -1, i32 [[VALUE0]], i32 [[LANEID1]], i32 31)
273+ // CHECK-NEXT: [[RES1_i32:%.*]] = tail call i32 @llvm.nvvm.shfl.sync.idx.i32(i32 -1, i32 [[VALUE1]], i32 [[LANEID1]], i32 31)
274+
275+ // CHECK-NEXT: [[PICK1:%.*]] = icmp eq i32 [[REGID1]], 0
276+ // CHECK-NEXT: select i1 [[PICK1]], i32 [[RES0_i32]], i32 [[RES1_i32]]
277+ %0 = tt.gather %arg1 [%arg0 ] {axis = 0 : i32 } : (tensor <64 x1 xf32 , #broadcasted_warp_2d >, tensor <64 x1 xi32 , #broadcasted_warp_2d >) -> tensor <64 x1 xf32 , #broadcasted_warp_2d >
278+ tt.return %0 : tensor <64 x1 xf32 , #broadcasted_warp_2d >
279+ }
280+
225281// Keep LLVM from DCE'ing the above functions. Use volatile stores to stop LLVM
226282// from removing unused function results.
227283tt.func @anchor (%ptr: !llvm.ptr ,
@@ -235,7 +291,11 @@ tt.func @anchor(%ptr: !llvm.ptr,
235291 %arg7: tensor <32 x2 xi32 , #span_2d_cols >,
236292 %arg8: tensor <32 x2 xf32 , #span_2d_cols >,
237293 %arg9: tensor <32 x16 xi32 , #crazy_2d_idx >,
238- %arg10: tensor <32 x16 xf32 , #crazy_2d_src >) {
294+ %arg10: tensor <32 x16 xf32 , #crazy_2d_src >,
295+ %arg11: tensor <16 xi32 , #broadcasted_lane_1d >,
296+ %arg12: tensor <16 xf32 , #broadcasted_lane_1d >,
297+ %arg13: tensor <64 x1 xi32 , #broadcasted_warp_2d >,
298+ %arg14: tensor <64 x1 xf32 , #broadcasted_warp_2d >) {
239299
240300 %0 = tt.call @gather_warp_local_trivial (%arg0 , %arg1 ) : (tensor <32 xi32 , #trivial_layout >, tensor <32 xf32 , #trivial_layout >) -> tensor <32 xf32 , #trivial_layout >
241301 %1 = builtin.unrealized_conversion_cast %0 : tensor <32 xf32 , #trivial_layout > to !llvm.struct <(f32 )>
@@ -265,6 +325,14 @@ tt.func @anchor(%ptr: !llvm.ptr,
265325 %13 = builtin.unrealized_conversion_cast %12 : tensor <32 x16 xf32 , #crazy_2d_idx > to !llvm.struct <(f32 , f32 , f32 , f32 )>
266326 llvm.store volatile %13 , %ptr : !llvm.struct <(f32 , f32 , f32 , f32 )>, !llvm.ptr
267327
328+ %14 = tt.call @gather_broadcasted_lane_1d (%arg11 , %arg12 ) : (tensor <16 xi32 , #broadcasted_lane_1d >, tensor <16 xf32 , #broadcasted_lane_1d >) -> tensor <16 xf32 , #broadcasted_lane_1d >
329+ %15 = builtin.unrealized_conversion_cast %14 : tensor <16 xf32 , #broadcasted_lane_1d > to !llvm.struct <(f32 )>
330+ llvm.store volatile %15 , %ptr : !llvm.struct <(f32 )>, !llvm.ptr
331+
332+ %16 = tt.call @gather_broadcasted_warp_2d (%arg13 , %arg14 ) : (tensor <64 x1 xi32 , #broadcasted_warp_2d >, tensor <64 x1 xf32 , #broadcasted_warp_2d >) -> tensor <64 x1 xf32 , #broadcasted_warp_2d >
333+ %17 = builtin.unrealized_conversion_cast %16 : tensor <64 x1 xf32 , #broadcasted_warp_2d > to !llvm.struct <(f32 , f32 )>
334+ llvm.store volatile %17 , %ptr : !llvm.struct <(f32 , f32 )>, !llvm.ptr
335+
268336 tt.return
269337}
270338
0 commit comments