@@ -559,13 +559,13 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ
 #smem = #ttg.shared_memory
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} {
   // CHECK-LABEL: _kernel_matmul_dependency
-  tt.func public @_kernel_matmul_dependency(%arg0: tensor<128x128x!tt.ptr<f8E4M3FNUZ>, #blocked>, %arg1: !tt.ptr<f8E4M3FNUZ> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: i32, %arg4: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg5: tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked1}>>) attributes {noinline = false} {
+  tt.func public @_kernel_matmul_dependency(%arg0: tensor<128x128x!tt.ptr<f8E4M3FN>, #blocked>, %arg1: !tt.ptr<f8E4M3FN> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: i32, %arg4: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg5: tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked1}>>) attributes {noinline = false} {
     %cst = arith.constant dense<0> : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked1}>>
     %cst_0 = arith.constant 1.000000e+00 : f32
     %c8_i32 = arith.constant 8 : i32
     %cst_1 = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #mma>
     %0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>
-    %1 = tt.splat %arg1 : !tt.ptr<f8E4M3FNUZ> -> tensor<128x128x!tt.ptr<f8E4M3FNUZ>, #blocked1>
+    %1 = tt.splat %arg1 : !tt.ptr<f8E4M3FN> -> tensor<128x128x!tt.ptr<f8E4M3FN>, #blocked1>
     %2:4 = scf.for %arg6 = %c8_i32 to %arg3 step %c8_i32 iter_args(%arg7 = %c8_i32, %arg8 = %c8_i32, %arg9 = %cst_1, %arg10 = %arg5) -> (i32, i32, tensor<128x128xf32, #mma>, tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked1}>>) : i32 {
       %3 = arith.addi %arg7, %c8_i32 : i32
       %4 = arith.cmpi eq, %3, %c8_i32 : i32
@@ -586,12 +586,12 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.targ
       %9 = arith.addi %8, %0 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>
       %10 = tt.expand_dims %9 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<128x1xi32, #blocked1>
       %11 = tt.broadcast %10 : tensor<128x1xi32, #blocked1> -> tensor<128x128xi32, #blocked1>
-      %12 = tt.addptr %1, %11 : tensor<128x128x!tt.ptr<f8E4M3FNUZ>, #blocked1>, tensor<128x128xi32, #blocked1>
-      %13 = tt.load %arg0 : tensor<128x128x!tt.ptr<f8E4M3FNUZ>, #blocked>
-      %14 = ttg.local_alloc %13 : (tensor<128x128xf8E4M3FNUZ, #blocked>) -> !ttg.memdesc<128x128xf8E4M3FNUZ, #shared, #smem>
-      %15 = tt.load %12 : tensor<128x128x!tt.ptr<f8E4M3FNUZ>, #blocked1>
-      %16 = ttg.local_alloc %15 : (tensor<128x128xf8E4M3FNUZ, #blocked1>) -> !ttg.memdesc<128x128xf8E4M3FNUZ, #shared1, #smem>
-      %17 = ttng.warp_group_dot %14, %16, %arg9 {inputPrecision = 0 : i32, maxNumImpreciseAcc = 1073741824 : i32} : !ttg.memdesc<128x128xf8E4M3FNUZ, #shared, #smem> * !ttg.memdesc<128x128xf8E4M3FNUZ, #shared1, #smem> -> tensor<128x128xf32, #mma>
+      %12 = tt.addptr %1, %11 : tensor<128x128x!tt.ptr<f8E4M3FN>, #blocked1>, tensor<128x128xi32, #blocked1>
+      %13 = tt.load %arg0 : tensor<128x128x!tt.ptr<f8E4M3FN>, #blocked>
+      %14 = ttg.local_alloc %13 : (tensor<128x128xf8E4M3FN, #blocked>) -> !ttg.memdesc<128x128xf8E4M3FN, #shared, #smem>
+      %15 = tt.load %12 : tensor<128x128x!tt.ptr<f8E4M3FN>, #blocked1>
+      %16 = ttg.local_alloc %15 : (tensor<128x128xf8E4M3FN, #blocked1>) -> !ttg.memdesc<128x128xf8E4M3FN, #shared1, #smem>
+      %17 = ttng.warp_group_dot %14, %16, %arg9 {inputPrecision = 0 : i32, maxNumImpreciseAcc = 1073741824 : i32} : !ttg.memdesc<128x128xf8E4M3FN, #shared, #smem> * !ttg.memdesc<128x128xf8E4M3FN, #shared1, #smem> -> tensor<128x128xf32, #mma>
       %18 = tt.splat %7 : f32 -> tensor<128x128xf32, #mma>
       %19 = arith.mulf %17, %18 : tensor<128x128xf32, #mma>
       %20 = scf.if %6 -> (tensor<128x128xf32, #mma>) {