Skip to content

Commit 461c7d9

Browse files
committed
Simplifier the structure of getStMatrixIntrinsicId
1 parent f964ba8 commit 461c7d9

File tree

1 file changed

+23
-31
lines changed

1 file changed

+23
-31
lines changed

mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp

Lines changed: 23 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -170,39 +170,31 @@ getStMatrixIntrinsicId(NVVM::MMALayout layout, int32_t num,
170170
NVVM::LdStMatrixShapeAttr shape,
171171
NVVM::LdStMatrixEltType eltType) {
172172
if (shape.getM() == 8 && shape.getN() == 8) {
173-
if (eltType == NVVM::LdStMatrixEltType::B16) {
174-
if (layout == NVVM::MMALayout::row) {
175-
switch (num) {
176-
case 1:
177-
return llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x1_b16;
178-
case 2:
179-
return llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x2_b16;
180-
case 4:
181-
return llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x4_b16;
182-
}
183-
} else {
184-
switch (num) {
185-
case 1:
186-
return llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x1_trans_b16;
187-
case 2:
188-
return llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x2_trans_b16;
189-
case 4:
190-
return llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x4_trans_b16;
191-
}
192-
}
173+
switch (num) {
174+
case 1:
175+
return (layout == NVVM::MMALayout::row)
176+
? llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x1_b16
177+
: llvm::Intrinsic::
178+
nvvm_stmatrix_sync_aligned_m8n8_x1_trans_b16;
179+
case 2:
180+
return (layout == NVVM::MMALayout::row)
181+
? llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x2_b16
182+
: llvm::Intrinsic::
183+
nvvm_stmatrix_sync_aligned_m8n8_x2_trans_b16;
184+
case 4:
185+
return (layout == NVVM::MMALayout::row)
186+
? llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x4_b16
187+
: llvm::Intrinsic::
188+
nvvm_stmatrix_sync_aligned_m8n8_x4_trans_b16;
193189
}
194190
} else if (shape.getM() == 16 && shape.getN() == 8) {
195-
if (eltType == NVVM::LdStMatrixEltType::B8) {
196-
if (layout == NVVM::MMALayout::col) {
197-
switch (num) {
198-
case 1:
199-
return llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m16n8_x1_trans_b8;
200-
case 2:
201-
return llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m16n8_x2_trans_b8;
202-
case 4:
203-
return llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m16n8_x4_trans_b8;
204-
}
205-
}
191+
switch (num) {
192+
case 1:
193+
return llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m16n8_x1_trans_b8;
194+
case 2:
195+
return llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m16n8_x2_trans_b8;
196+
case 4:
197+
return llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m16n8_x4_trans_b8;
206198
}
207199
}
208200
llvm_unreachable("unknown stmatrix kind");

0 commit comments

Comments
 (0)