@@ -76,15 +76,15 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.thr
76
76
// CHECK-SAME: %[[A:.*]]: !llvm.struct<(f32, f32, f32, f32)>, %[[B:.*]]: !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>, %[[C:.*]]: !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>, %[[PTR_1:.*]]: !llvm.ptr<1>) attributes {intel_reqd_sub_group_size = 16 : i32, reqd_work_group_size = array<i32: 16, 1, 1>} {
77
77
tt.func @dot_f32_tf32_tf32_f32_1(%a: tensor<8x8xf32, #dot_operand_a>, %b: tensor<8x16xf32, #dot_operand_b>, %c: tensor<8x16xf32, #dpas>) {
  // COM: To simplify, only check RTNE and its usage for the last element of A, B, C.
  // COM: A: the last struct element is extracted and inserted into a vector<4xf32>;
  // COM: that vector is the operand of the 4-wide RoundFToTF32INTEL call.
  // CHECK: %[[A_EXTR_LAST_VAL:.*]] = llvm.extractvalue %[[A]][3]
  // CHECK: %[[A_LAST_VAL:.*]] = llvm.insertelement %[[A_EXTR_LAST_VAL]], %{{.*}} : vector<4xf32>
  // CHECK: %[[A_RTNE_VAL:.*]] = llvm.call spir_funccc @_Z25__spirv_RoundFToTF32INTELDv4_f(%[[A_LAST_VAL]])
  // COM: B: same sequence as A, with a vector<8xf32> and the 8-wide RoundFToTF32INTEL call.
  // CHECK: %[[B_EXTR_LAST_VAL:.*]] = llvm.extractvalue %[[B]][7]
  // CHECK: %[[B_LAST_VAL:.*]] = llvm.insertelement %[[B_EXTR_LAST_VAL]], %{{.*}} : vector<8xf32>
  // CHECK: %[[B_RTNE_VAL:.*]] = llvm.call spir_funccc @_Z25__spirv_RoundFToTF32INTELDv8_f(%[[B_LAST_VAL]])
  // COM: C: the last element is only packed into a vector<8xf32>; no rounding call is matched for it.
  // CHECK: %[[C_LAST_VAL:.*]] = llvm.extractvalue %[[C]][7]
  // CHECK: %[[C_0:.*]] = llvm.insertelement %[[C_LAST_VAL]], %{{.*}}{{\[}}%{{.*}} : i32] : vector<8xf32>
  // COM: The matrix multiply-accumulate call must take the rounded A and B vectors and the packed C vector.
  // CHECK: llvm.call spir_funccc @_Z45__spirv_SubgroupMatrixMultiplyAccumulateINTELiDv4_fDv8_fS0_i(%{{.*}}, %[[A_RTNE_VAL]], %[[B_RTNE_VAL]], %[[C_0]], %{{.*}}} : (i32, vector<4xf32>, vector<8xf32>, vector<8xf32>, i32) -> vector<8xf32>
  %0 = tt.dot %a, %b, %c, inputPrecision = tf32 : tensor<8x8xf32, #dot_operand_a> * tensor<8x16xf32, #dot_operand_b> -> tensor<8x16xf32, #dpas>
  tt.return
}
0 commit comments