@@ -77,15 +77,15 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} {
7777 // CHECK-SAME: %[[C:.*]]: !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>) attributes {intel_reqd_sub_group_size = 32 : i32, triton_gen.max_work_group_size = array<i32: 32, 1, 1>} {
7878 tt.func @dot_f32_tf32_tf32_f32_1 (%a: tensor <8 x8 xf32 , #dot_operand_a >, %b: tensor <8 x16 xf32 , #dot_operand_b >, %c: tensor <8 x16 xf32 , #dpas >) {
7979 // COM: To simplify, only check RTNE and its usage for the last element of A, B, C
80- // CHECK %[[A_LAST_VAL:.*]] = llvm.extractvalue %[[A]][3]
81- // CHECK %[[A_RTNE_VAL:.*]] = llvm.call spir_funccc @_Z25__spirv_RoundFToTF32INTELf(%[[A_LAST_VAL]])
82- // CHECK %[[A_0:.*]] = llvm.insertelement %[[A_RTNE_VAL]], %{{.*}}{{\[}}%{{.*}} : i32] : vector<4xf32>
83- // CHECK %[[B_LAST_VAL:.*]] = llvm.extractvalue %[[B]][7]
84- // CHECK %[[B_RTNE_VAL:.*]] = llvm.call spir_funccc @_Z25__spirv_RoundFToTF32INTELf(%[[B_LAST_VAL]])
85- // CHECK %[[B_0:.*]] = llvm.insertelement %[[B_RTNE_VAL]], %{{.*}}{{\[}}%{{.*}} : i32] : vector<8xf32>
86- // CHECK %[[C_LAST_VAL:.*]] = llvm.extractvalue %[[C]][7]
87- // CHECK %[[C_0:.*]] = llvm.insertelement %[[C_LAST_VAL]], %{{.*}}{{\[}}%{{.*}} : i32] : vector<8xf32>
88- // CHECK : llvm.call spir_funccc @_Z39intel_sub_group_tf32_tf32_matrix_mad_k8Dv4_fDv8_fS0_(%[[A_0]], %[[B_0]], %[[C_0]]) {{.*}} : (vector<4xf32>, vector<8xf32>, vector<8xf32>) -> vector<8xf32>
80+ // CHECK: %[[A_LAST_VAL:.*]] = llvm.extractvalue %[[A]][3]
81+ // CHECK: %[[A_RTNE_VAL:.*]] = llvm.call spir_funccc @_Z25__spirv_RoundFToTF32INTELf(%[[A_LAST_VAL]])
82+ // CHECK: %[[A_0:.*]] = llvm.insertelement %[[A_RTNE_VAL]], %{{.*}}{{\[}}%{{.*}} : i32] : vector<4xf32>
83+ // CHECK: %[[B_LAST_VAL:.*]] = llvm.extractvalue %[[B]][7]
84+ // CHECK: %[[B_RTNE_VAL:.*]] = llvm.call spir_funccc @_Z25__spirv_RoundFToTF32INTELf(%[[B_LAST_VAL]])
85+ // CHECK: %[[B_0:.*]] = llvm.insertelement %[[B_RTNE_VAL]], %{{.*}}{{\[}}%{{.*}} : i32] : vector<8xf32>
86+ // CHECK: %[[C_LAST_VAL:.*]] = llvm.extractvalue %[[C]][7]
87+ // CHECK: %[[C_0:.*]] = llvm.insertelement %[[C_LAST_VAL]], %{{.*}}{{\[}}%{{.*}} : i32] : vector<8xf32>
88+ // CHECK: llvm.call spir_funccc @_Z39intel_sub_group_tf32_tf32_matrix_mad_k8Dv4_fDv8_fS0_(%[[A_0]], %[[B_0]], %[[C_0]]) {{.*}} : (vector<4xf32>, vector<8xf32>, vector<8xf32>) -> vector<8xf32>
8989 %0 = tt.dot %a , %b , %c , inputPrecision = tf32 : tensor <8 x8 xf32 , #dot_operand_a > * tensor <8 x16 xf32 , #dot_operand_b > -> tensor <8 x16 xf32 , #dpas >
9090 tt.return
9191 }
0 commit comments