@@ -3034,21 +3034,6 @@ def kernel(X, Y, N, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
     np.testing.assert_allclose(y_tri, y_ref, rtol=0.01, atol=1e-3)


-scan_layouts = [
-    BlockedLayout([1, 4], [4, THREADS_PER_WARP // 4], [4, 1], [0, 1], [1, 1], [1, 1], [0, 1]),
-    BlockedLayout([1, 4], [8, THREADS_PER_WARP // 8], [4, 1], [0, 1], [1, 1], [1, 1], [0, 1]),
-    BlockedLayout([4, 1], [4, THREADS_PER_WARP // 4], [1, 4], [0, 1], [1, 1], [1, 1], [0, 1]),
-    BlockedLayout([2, 2], [4, THREADS_PER_WARP // 4], [2, 2], [0, 1], [1, 1], [1, 1], [0, 1]),
-    BlockedLayout([2, 2], [8, THREADS_PER_WARP // 8], [2, 2], [0, 1], [1, 1], [1, 1], [0, 1]),
-    BlockedLayout([1, 4], [4, THREADS_PER_WARP // 4], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]),
-    BlockedLayout([1, 4], [8, THREADS_PER_WARP // 8], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]),
-    BlockedLayout([4, 1], [4, THREADS_PER_WARP // 4], [1, 4], [1, 0], [1, 1], [1, 1], [0, 1]),
-    BlockedLayout([2, 2], [4, THREADS_PER_WARP // 4], [2, 2], [1, 0], [1, 1], [1, 1], [0, 1]),
-    BlockedLayout([2, 2], [8, THREADS_PER_WARP // 8], [2, 2], [1, 0], [1, 1], [1, 1], [0, 1]),
-    BlockedLayout([1, 2], [1, THREADS_PER_WARP // 1], [1, 4], [1, 0], [1, 1], [1, 1], [0, 1]),
-]
-
-
 def test_no_rematerialization_op():

     if torch.version.hip:
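
The scan_layouts list deleted above fed the test_scan_layouts parameterization removed in the next hunk; each entry describes one #ttg.blocked layout. As a reference for reading those seven-argument constructor calls, here is a minimal sketch of the BlockedLayout helper they assume, written on the assumption that it mirrors the wrapper used by the surviving layouts list further down (the positional arguments map to sizePerThread, threadsPerWarp, warpsPerCTA, order, CTAsPerCGA, CTASplitNum, and CTAOrder):

# Sketch only (assumed shape of the test-suite helper): renders the seven
# layout parameters into the #ttg.blocked attribute string that the tests
# splice into TTGIR sources via {src_layout}.
class BlockedLayout:

    def __init__(self, size_per_thread, threads_per_warp, warps_per_cta, order, ctas_per_cga, cta_split_num,
                 cta_order):
        self.size_per_thread = size_per_thread
        self.threads_per_warp = threads_per_warp
        self.warps_per_cta = warps_per_cta
        self.order = order
        self.ctas_per_cga = ctas_per_cga
        self.cta_split_num = cta_split_num
        self.cta_order = cta_order

    def __str__(self):
        return (f"#ttg.blocked<{{sizePerThread={self.size_per_thread}, threadsPerWarp={self.threads_per_warp}, "
                f"warpsPerCTA={self.warps_per_cta}, order={self.order}, CTAsPerCGA={self.ctas_per_cga}, "
                f"CTASplitNum={self.cta_split_num}, CTAOrder={self.cta_order}}}>")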
@@ -3094,73 +3079,6 @@ def kernel(
     assert compiled_kernel.asm["ttgir"].count('"tt.reduce"') == 1, "we shouldn't rematerialize tt.reduce"


-@pytest.mark.parametrize("M, N", [[32, 16], [32, 32], [32, 64], [64, 32]])
-@pytest.mark.parametrize("src_layout", scan_layouts)
-@pytest.mark.parametrize("axis", [0, 1])
-@pytest.mark.parametrize("add_overflow_check", [False, True])
-def test_scan_layouts(M, N, src_layout, axis, add_overflow_check, device, tmp_path: pathlib.Path):
-
-    overflow_check = """
-    %17 = arith.extsi %arg2 : i32 to i64
-    %18 = arith.extsi %arg3 : i32 to i64
-    %19 = arith.addi %17, %18 : i64
-    %i32.min = arith.constant -2147483648: i64
-    %i32.max = arith.constant 2147483647: i64
-    %20 = arith.cmpi slt, %19, %i32.max : i64
-    %21 = arith.cmpi sge, %19, %i32.min : i64
-    %22 = arith.andi %20, %21 : i1
-    tt.assert %22, "overflow detected" : i1
-    """
-
-    ir = f"""
-    #blocked = {src_layout}
-    module attributes {{"ttg.num-warps" = 4 : i32, "ttg.num-ctas" = 1 : i32, "ttg.threads-per-warp" = {THREADS_PER_WARP} : i32}} {{
-    tt.func public @kernel_0d1d(%arg0: !tt.ptr<i32> {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr<i32> {{tt.divisibility = 16 : i32}}) {{
-      %cst = arith.constant dense<{N}> : tensor<{M}x1xi32, #blocked>
-      %0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #ttg.slice<{{dim = 1, parent = #blocked}}>>
-      %1 = tt.expand_dims %0 {{axis = 1 : i32}} : tensor<{M}xi32, #ttg.slice<{{dim = 1, parent = #blocked}}>> -> tensor<{M}x1xi32, #blocked>
-      %2 = arith.muli %1, %cst : tensor<{M}x1xi32, #blocked>
-      %3 = tt.splat %arg0 : !tt.ptr<i32> -> tensor<{M}x1x!tt.ptr<i32>, #blocked>
-      %4 = tt.addptr %3, %2 : tensor<{M}x1x!tt.ptr<i32>, #blocked>, tensor<{M}x1xi32, #blocked>
-      %5 = tt.make_range {{end = {N} : i32, start = 0 : i32}} : tensor<{N}xi32, #ttg.slice<{{dim = 0, parent = #blocked}}>>
-      %6 = tt.expand_dims %5 {{axis = 0 : i32}} : tensor<{N}xi32, #ttg.slice<{{dim = 0, parent = #blocked}}>> -> tensor<1x{N}xi32, #blocked>
-      %7 = tt.broadcast %4 : tensor<{M}x1x!tt.ptr<i32>, #blocked> -> tensor<{M}x{N}x!tt.ptr<i32>, #blocked>
-      %8 = tt.broadcast %6 : tensor<1x{N}xi32, #blocked> -> tensor<{M}x{N}xi32, #blocked>
-      %9 = tt.addptr %7, %8 : tensor<{M}x{N}x!tt.ptr<i32>, #blocked>, tensor<{M}x{N}xi32, #blocked>
-      %10 = tt.load %9 : tensor<{M}x{N}x!tt.ptr<i32>, #blocked>
-      %11 = "tt.scan"(%10) <{{axis = {axis} : i32, reverse = false}}> ({{
-      ^bb0(%arg2: i32, %arg3: i32):
-        %16 = arith.addi %arg2, %arg3 : i32{overflow_check if add_overflow_check else ""}
-        tt.scan.return %16 : i32
-      }}) : (tensor<{M}x{N}xi32, #blocked>) -> tensor<{M}x{N}xi32, #blocked>
-      %12 = tt.splat %arg1 : !tt.ptr<i32> -> tensor<{M}x1x!tt.ptr<i32>, #blocked>
-      %13 = tt.addptr %12, %2 : tensor<{M}x1x!tt.ptr<i32>, #blocked>, tensor<{M}x1xi32, #blocked>
-      %14 = tt.broadcast %13 : tensor<{M}x1x!tt.ptr<i32>, #blocked> -> tensor<{M}x{N}x!tt.ptr<i32>, #blocked>
-      %15 = tt.addptr %14, %8 : tensor<{M}x{N}x!tt.ptr<i32>, #blocked>, tensor<{M}x{N}xi32, #blocked>
-      tt.store %15, %11 : tensor<{M}x{N}x!tt.ptr<i32>, #blocked>
-      tt.return
-    }}
-    }}
-    """
-
-    temp_file = tmp_path / "test_scan_layouts.ttgir"
-    temp_file.write_text(ir)
-    kernel = triton.compile(str(temp_file))
-
-    rs = RandomState(17)
-    x = rs.randint(-100, 100, (M, N)).astype('int32')
-
-    z = np.zeros((M, N)).astype('int32')
-    x_tri = torch.tensor(x, device=device)
-    z_tri = torch.tensor(z, device=device)
-
-    kernel[(1, 1, 1)](x_tri, z_tri)
-
-    z_ref = np.cumsum(x, axis=axis)
-
-    np.testing.assert_equal(z_ref, z_tri.cpu().numpy())
-
-
 layouts = [
     BlockedLayout([1, 4], [8, THREADS_PER_WARP // 8], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]),
     BlockedLayout([1, 4], [8, THREADS_PER_WARP // 8], [4, 1], [0, 1], [1, 1], [1, 1], [0, 1]),
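
The test_scan_layouts removed above follows the file's compile-TTGIR-directly pattern: the generated IR is written under tmp_path, compiled without the Python frontend, launched on a single CTA (the MxN tensor shape equals the block shape, so one program covers the whole input), and checked against np.cumsum, since an inclusive additive scan along an axis is exactly a cumulative sum. When add_overflow_check is true, the scan body additionally widens both i32 operands to i64, adds them, and asserts the sum still fits in [-2^31, 2^31). A condensed sketch of that harness follows; run_scan_ttgir is a hypothetical name introduced here for illustration:

# Condensed sketch of the removed test's harness. run_scan_ttgir is a
# hypothetical helper, not a name from the test file.
import pathlib

import numpy as np
import torch
from numpy.random import RandomState

import triton


def run_scan_ttgir(ir: str, M: int, N: int, axis: int, device: str, tmp_path: pathlib.Path):
    temp_file = tmp_path / "test_scan_layouts.ttgir"
    temp_file.write_text(ir)
    kernel = triton.compile(str(temp_file))  # compile the TTGIR text directly

    rs = RandomState(17)  # fixed seed, as in the removed test
    x = rs.randint(-100, 100, (M, N)).astype('int32')
    x_tri = torch.tensor(x, device=device)
    z_tri = torch.zeros((M, N), dtype=torch.int32, device=device)

    kernel[(1, 1, 1)](x_tri, z_tri)  # one CTA covers the whole MxN tensor

    # Inclusive additive scan == cumulative sum.
    np.testing.assert_equal(np.cumsum(x, axis=axis), z_tri.cpu().numpy())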