@@ -1516,3 +1516,54 @@ module attributes {"ttg.num-warps" = 4 : i32} {
1516
1516
tt.return %7 : tensor <1024 xf32 >
1517
1517
}
1518
1518
}
1519
+ // -----
1520
+
1521
+ module attributes {" ttg.num-warps" = 4 : i32 } {
1522
+ tt.func @propagate_divisibility (%arg0: !tt.ptr <f32 >) -> tensor <1024 xf32 > {
1523
+ %c1024_i32 = arith.constant 1024 : i32
1524
+ %0 = tt.get_program_id x : i32
1525
+ %1 = arith.muli %0 , %c1024_i32 : i32
1526
+ %2 = tt.splat %1 : i32 -> tensor <1024 xi32 >
1527
+ %3 = tt.splat %arg0 : !tt.ptr <f32 > -> tensor <1024 x!tt.ptr <f32 >>
1528
+ %4 = tt.addptr %3 , %2 {tt.divisibility = 16 : i32 , misc.misc = 3 : i32 } : tensor <1024 x!tt.ptr <f32 >>, tensor <1024 xi32 >
1529
+ %5 = tt.load %4 : tensor <1024 x!tt.ptr <f32 >>
1530
+ tt.return %5 : tensor <1024 xf32 >
1531
+ }
1532
+ }
1533
+
1534
+ // CHECK-LABEL: tt.func @propagate_divisibility(
1535
+ // CHECK-SAME: %[[VAL_0:.*]]: !tt.ptr<f32>) -> tensor<1024xf32> {
1536
+ // CHECK: %[[VAL_1:.*]] = arith.constant 1024 : i32
1537
+ // CHECK: %[[VAL_2:.*]] = tt.get_program_id x : i32
1538
+ // CHECK: %[[VAL_3:.*]] = arith.muli %[[VAL_2]], %[[VAL_1]] : i32
1539
+ // CHECK: %[[VAL_4:.*]] = tt.addptr %[[VAL_0]], %[[VAL_3]] {tt.divisibility = 16 : i32} : !tt.ptr<f32>, i32
1540
+ // CHECK: %[[VAL_5:.*]] = tt.splat %[[VAL_4]] : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>>
1541
+ // CHECK: %[[VAL_6:.*]] = tt.load %[[VAL_5]] : tensor<1024x!tt.ptr<f32>>
1542
+ // CHECK: tt.return %[[VAL_6]] : tensor<1024xf32>
1543
+ // CHECK: }
1544
+
1545
+ // -----
1546
+
1547
+ module attributes {" ttg.num-warps" = 4 : i32 } {
1548
+ tt.func @divisiblity_changeing_dims (%arg0: !tt.ptr <f32 >) -> tensor <1024 x32 xf32 > {
1549
+ %c1024_i32 = arith.constant 1024 : i32
1550
+ %0 = tt.get_program_id x : i32
1551
+ %1 = arith.muli %0 , %c1024_i32 : i32
1552
+ %2 = tt.splat %1 : i32 -> tensor <1024 x32 xi32 >
1553
+ %3 = tt.splat %arg0 : !tt.ptr <f32 > -> tensor <1024 x32 x!tt.ptr <f32 >>
1554
+ %4 = tt.addptr %3 , %2 {tt.divisibility = dense <[1 , 16 ]> : tensor <2 xi32 >} : tensor <1024 x32 x!tt.ptr <f32 >>, tensor <1024 x32 xi32 >
1555
+ %5 = tt.load %4 : tensor <1024 x32 x!tt.ptr <f32 >>
1556
+ tt.return %5 : tensor <1024 x32 xf32 >
1557
+ }
1558
+ }
1559
+
1560
+ // CHECK-LABEL: tt.func @divisiblity_changeing_dims(
1561
+ // CHECK-SAME: %[[VAL_0:.*]]: !tt.ptr<f32>) -> tensor<1024x32xf32> {
1562
+ // CHECK: %[[VAL_1:.*]] = arith.constant 1024 : i32
1563
+ // CHECK: %[[VAL_2:.*]] = tt.get_program_id x : i32
1564
+ // CHECK: %[[VAL_3:.*]] = arith.muli %[[VAL_2]], %[[VAL_1]] : i32
1565
+ // CHECK: %[[VAL_4:.*]] = tt.addptr %[[VAL_0]], %[[VAL_3]] : !tt.ptr<f32>, i32
1566
+ // CHECK: %[[VAL_5:.*]] = tt.splat %[[VAL_4]] : !tt.ptr<f32> -> tensor<1024x32x!tt.ptr<f32>>
1567
+ // CHECK: %[[VAL_6:.*]] = tt.load %[[VAL_5]] : tensor<1024x32x!tt.ptr<f32>>
1568
+ // CHECK: tt.return %[[VAL_6]] : tensor<1024x32xf32>
1569
+ // CHECK: }
0 commit comments