@@ -1516,3 +1516,54 @@ module attributes {"ttg.num-warps" = 4 : i32} {
15161516 tt.return %7 : tensor <1024 xf32 >
15171517 }
15181518}
1519+ // -----
1520+
// Test case: an element-wise tt.addptr whose pointer and offset operands are
// both produced by tt.splat, and which carries a scalar `tt.divisibility`
// attribute together with an unrelated `misc.misc` attribute.
//
// The expectations after the module require the output to perform the
// tt.addptr on the scalar pointer and scalar offset, to apply tt.splat to the
// result, and to keep exactly `{tt.divisibility = 16 : i32}` as the attribute
// dictionary of the scalar tt.addptr (`misc.misc` does not appear).
// NOTE(review): the RUN line / pass under test is outside this chunk.
module attributes {"ttg.num-warps" = 4 : i32} {
  tt.func @propagate_divisibility(%arg0: !tt.ptr<f32>) -> tensor<1024xf32> {
    %c1024_i32 = arith.constant 1024 : i32
    %0 = tt.get_program_id x : i32
    // Scalar offset: program id multiplied by 1024.
    %1 = arith.muli %0, %c1024_i32 : i32
    // Both operands of the pointer addition are splats of scalars.
    %2 = tt.splat %1 : i32 -> tensor<1024xi32>
    %3 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>>
    // Tensor-level addptr annotated with a single i32 divisibility value.
    %4 = tt.addptr %3, %2 {tt.divisibility = 16 : i32, misc.misc = 3 : i32} : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32>
    %5 = tt.load %4 : tensor<1024x!tt.ptr<f32>>
    tt.return %5 : tensor<1024xf32>
  }
}

// CHECK-LABEL: tt.func @propagate_divisibility(
// CHECK-SAME: %[[VAL_0:.*]]: !tt.ptr<f32>) -> tensor<1024xf32> {
// CHECK: %[[VAL_1:.*]] = arith.constant 1024 : i32
// CHECK: %[[VAL_2:.*]] = tt.get_program_id x : i32
// CHECK: %[[VAL_3:.*]] = arith.muli %[[VAL_2]], %[[VAL_1]] : i32
// CHECK: %[[VAL_4:.*]] = tt.addptr %[[VAL_0]], %[[VAL_3]] {tt.divisibility = 16 : i32} : !tt.ptr<f32>, i32
// CHECK: %[[VAL_5:.*]] = tt.splat %[[VAL_4]] : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>>
// CHECK: %[[VAL_6:.*]] = tt.load %[[VAL_5]] : tensor<1024x!tt.ptr<f32>>
// CHECK: tt.return %[[VAL_6]] : tensor<1024xf32>
// CHECK: }
1544+
1545+ // -----
1546+
// Test case: same shape of computation as a splat + tt.addptr + tt.load, but
// on a rank-2 tensor (1024x32) and with `tt.divisibility` given as a dense
// two-element tensor attribute (one value per dimension) instead of a single
// i32.
//
// The expectations after the module require the output to perform the
// tt.addptr on the scalar pointer and scalar offset with no attribute
// dictionary at all, followed by a tt.splat of the result to 1024x32.
// NOTE(review): presumably the per-dimension divisibility is dropped because
// it has no meaning on the scalar op -- confirm against the pass
// implementation; the RUN line / pass under test is outside this chunk.
module attributes {"ttg.num-warps" = 4 : i32} {
  tt.func @divisiblity_changeing_dims(%arg0: !tt.ptr<f32>) -> tensor<1024x32xf32> {
    %c1024_i32 = arith.constant 1024 : i32
    %0 = tt.get_program_id x : i32
    // Scalar offset: program id multiplied by 1024.
    %1 = arith.muli %0, %c1024_i32 : i32
    // Both operands of the pointer addition are splats of scalars.
    %2 = tt.splat %1 : i32 -> tensor<1024x32xi32>
    %3 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<1024x32x!tt.ptr<f32>>
    // Tensor-level addptr annotated with a two-element divisibility tensor.
    %4 = tt.addptr %3, %2 {tt.divisibility = dense<[1, 16]> : tensor<2xi32>} : tensor<1024x32x!tt.ptr<f32>>, tensor<1024x32xi32>
    %5 = tt.load %4 : tensor<1024x32x!tt.ptr<f32>>
    tt.return %5 : tensor<1024x32xf32>
  }
}

// CHECK-LABEL: tt.func @divisiblity_changeing_dims(
// CHECK-SAME: %[[VAL_0:.*]]: !tt.ptr<f32>) -> tensor<1024x32xf32> {
// CHECK: %[[VAL_1:.*]] = arith.constant 1024 : i32
// CHECK: %[[VAL_2:.*]] = tt.get_program_id x : i32
// CHECK: %[[VAL_3:.*]] = arith.muli %[[VAL_2]], %[[VAL_1]] : i32
// CHECK: %[[VAL_4:.*]] = tt.addptr %[[VAL_0]], %[[VAL_3]] : !tt.ptr<f32>, i32
// CHECK: %[[VAL_5:.*]] = tt.splat %[[VAL_4]] : !tt.ptr<f32> -> tensor<1024x32x!tt.ptr<f32>>
// CHECK: %[[VAL_6:.*]] = tt.load %[[VAL_5]] : tensor<1024x32x!tt.ptr<f32>>
// CHECK: tt.return %[[VAL_6]] : tensor<1024x32xf32>
// CHECK: }
0 commit comments