@@ -170,39 +170,31 @@ getStMatrixIntrinsicId(NVVM::MMALayout layout, int32_t num,
170170 NVVM::LdStMatrixShapeAttr shape,
171171 NVVM::LdStMatrixEltType eltType) {
172172 if (shape.getM () == 8 && shape.getN () == 8 ) {
173- if (eltType == NVVM::LdStMatrixEltType::B16) {
174- if (layout == NVVM::MMALayout::row) {
175- switch (num) {
176- case 1 :
177- return llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x1_b16;
178- case 2 :
179- return llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x2_b16;
180- case 4 :
181- return llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x4_b16;
182- }
183- } else {
184- switch (num) {
185- case 1 :
186- return llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x1_trans_b16;
187- case 2 :
188- return llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x2_trans_b16;
189- case 4 :
190- return llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x4_trans_b16;
191- }
192- }
173+ switch (num) {
174+ case 1 :
175+ return (layout == NVVM::MMALayout::row)
176+ ? llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x1_b16
177+ : llvm::Intrinsic::
178+ nvvm_stmatrix_sync_aligned_m8n8_x1_trans_b16;
179+ case 2 :
180+ return (layout == NVVM::MMALayout::row)
181+ ? llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x2_b16
182+ : llvm::Intrinsic::
183+ nvvm_stmatrix_sync_aligned_m8n8_x2_trans_b16;
184+ case 4 :
185+ return (layout == NVVM::MMALayout::row)
186+ ? llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x4_b16
187+ : llvm::Intrinsic::
188+ nvvm_stmatrix_sync_aligned_m8n8_x4_trans_b16;
193189 }
194190 } else if (shape.getM () == 16 && shape.getN () == 8 ) {
195- if (eltType == NVVM::LdStMatrixEltType::B8) {
196- if (layout == NVVM::MMALayout::col) {
197- switch (num) {
198- case 1 :
199- return llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m16n8_x1_trans_b8;
200- case 2 :
201- return llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m16n8_x2_trans_b8;
202- case 4 :
203- return llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m16n8_x4_trans_b8;
204- }
205- }
191+ switch (num) {
192+ case 1 :
193+ return llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m16n8_x1_trans_b8;
194+ case 2 :
195+ return llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m16n8_x2_trans_b8;
196+ case 4 :
197+ return llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m16n8_x4_trans_b8;
206198 }
207199 }
208200 llvm_unreachable (" unknown stmatrix kind" );
0 commit comments