@@ -135,33 +135,83 @@ static llvm::Intrinsic::ID getVoteSyncIntrinsicId(NVVM::VoteSyncKind kind) {
135135 llvm_unreachable (" unsupported vote kind" );
136136}
137137
138- // / Return the intrinsic ID associated with ldmatrix for the given paramters.
139- static llvm::Intrinsic::ID getLdMatrixIntrinsicId (NVVM::MMALayout layout,
140- int32_t num) {
141- if (layout == NVVM::MMALayout::row) {
138+ static llvm::Intrinsic::ID
139+ getLdMatrixIntrinsicId (NVVM::MMALayout layout, int32_t num,
140+ NVVM::LdStMatrixShapeAttr shape,
141+ NVVM::LdStMatrixEltType eltType) {
142+ if (shape.getM () == 8 && shape.getN () == 8 ) {
142143 switch (num) {
143144 case 1 :
144- return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x1_b16;
145+ return (layout == NVVM::MMALayout::row)
146+ ? llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x1_b16
147+ : llvm::Intrinsic::
148+ nvvm_ldmatrix_sync_aligned_m8n8_x1_trans_b16;
145149 case 2 :
146- return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x2_b16;
150+ return (layout == NVVM::MMALayout::row)
151+ ? llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x2_b16
152+ : llvm::Intrinsic::
153+ nvvm_ldmatrix_sync_aligned_m8n8_x2_trans_b16;
147154 case 4 :
148- return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x4_b16;
149- default :
150- llvm_unreachable (" unsupported number of matrix" );
155+ return (layout == NVVM::MMALayout::row)
156+ ? llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x4_b16
157+ : llvm::Intrinsic::
158+ nvvm_ldmatrix_sync_aligned_m8n8_x4_trans_b16;
151159 }
152-
153- } else {
154- switch (num) {
155- case 1 :
156- return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x1_trans_b16;
157- case 2 :
158- return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x2_trans_b16;
159- case 4 :
160- return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x4_trans_b16;
161- default :
162- llvm_unreachable (" unsupported number of matrix" );
160+ } else if (shape.getM () == 8 && shape.getN () == 16 ) {
161+ if (eltType == NVVM::LdStMatrixEltType::B8X16_B6X16_P32) {
162+ switch (num) {
163+ case 1 :
164+ return llvm::Intrinsic::
165+ nvvm_ldmatrix_sync_aligned_m8n16_x1_b8x16_b6x16_p32;
166+ case 2 :
167+ return llvm::Intrinsic::
168+ nvvm_ldmatrix_sync_aligned_m8n16_x2_b8x16_b6x16_p32;
169+ case 4 :
170+ return llvm::Intrinsic::
171+ nvvm_ldmatrix_sync_aligned_m8n16_x4_b8x16_b6x16_p32;
172+ }
173+ } else if (eltType == NVVM::LdStMatrixEltType::B8X16_B4X16_P64) {
174+ switch (num) {
175+ case 1 :
176+ return llvm::Intrinsic::
177+ nvvm_ldmatrix_sync_aligned_m8n16_x1_b8x16_b4x16_p64;
178+ case 2 :
179+ return llvm::Intrinsic::
180+ nvvm_ldmatrix_sync_aligned_m8n16_x2_b8x16_b4x16_p64;
181+ case 4 :
182+ return llvm::Intrinsic::
183+ nvvm_ldmatrix_sync_aligned_m8n16_x4_b8x16_b4x16_p64;
184+ }
185+ }
186+ } else if (shape.getM () == 16 && shape.getN () == 16 ) {
187+ if (eltType == NVVM::LdStMatrixEltType::B8) {
188+ switch (num) {
189+ case 1 :
190+ return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m16n16_x1_trans_b8;
191+ case 2 :
192+ return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m16n16_x2_trans_b8;
193+ }
194+ } else if (eltType == NVVM::LdStMatrixEltType::B8X16_B6X16_P32) {
195+ switch (num) {
196+ case 1 :
197+ return llvm::Intrinsic::
198+ nvvm_ldmatrix_sync_aligned_m16n16_x1_trans_b8x16_b6x16_p32;
199+ case 2 :
200+ return llvm::Intrinsic::
201+ nvvm_ldmatrix_sync_aligned_m16n16_x2_trans_b8x16_b6x16_p32;
202+ }
203+ } else if (eltType == NVVM::LdStMatrixEltType::B8X16_B4X16_P64) {
204+ switch (num) {
205+ case 1 :
206+ return llvm::Intrinsic::
207+ nvvm_ldmatrix_sync_aligned_m16n16_x1_trans_b8x16_b4x16_p64;
208+ case 2 :
209+ return llvm::Intrinsic::
210+ nvvm_ldmatrix_sync_aligned_m16n16_x2_trans_b8x16_b4x16_p64;
211+ }
163212 }
164213 }
214+ llvm_unreachable (" unknown ldmatrix kind" );
165215}
166216
167217/// Return the intrinsic ID associated with stmatrix for the given parameters.
0 commit comments