@@ -329,11 +329,64 @@ void fir::runtime::genEoshiftVector(fir::FirOpBuilder &builder,
329
329
builder.create <fir::CallOp>(loc, eoshiftFunc, args);
330
330
}
331
331
332
+ // / Define ForcedMatmul<ACAT><AKIND><BCAT><BKIND> models.
333
+ struct ForcedMatmulTypeModel {
334
+ static constexpr fir::runtime::FuncTypeBuilderFunc getTypeModel () {
335
+ return [](mlir::MLIRContext *ctx) {
336
+ auto boxRefTy =
337
+ fir::runtime::getModel<Fortran::runtime::Descriptor &>()(ctx);
338
+ auto boxTy =
339
+ fir::runtime::getModel<const Fortran::runtime::Descriptor &>()(ctx);
340
+ auto strTy = fir::runtime::getModel<const char *>()(ctx);
341
+ auto intTy = fir::runtime::getModel<int >()(ctx);
342
+ auto voidTy = fir::runtime::getModel<void >()(ctx);
343
+ return mlir::FunctionType::get (
344
+ ctx, {boxRefTy, boxTy, boxTy, strTy, intTy}, {voidTy});
345
+ };
346
+ }
347
+ };
348
+
349
+ #define MATMUL_INSTANCE (ACAT, AKIND, BCAT, BKIND ) \
350
+ struct ForcedMatmul ##ACAT##AKIND##BCAT##BKIND \
351
+ : public ForcedMatmulTypeModel { \
352
+ static constexpr const char *name = \
353
+ ExpandAndQuoteKey (RTNAME(Matmul##ACAT##AKIND##BCAT##BKIND)); \
354
+ };
355
+
356
+ #define MATMUL_DIRECT_INSTANCE (ACAT, AKIND, BCAT, BKIND )
357
+ #define MATMUL_FORCE_ALL_TYPES 1
358
+
359
+ #include " flang/Runtime/matmul-instances.inc"
360
+
332
361
// / Generate call to Matmul intrinsic runtime routine.
333
362
void fir::runtime::genMatmul (fir::FirOpBuilder &builder, mlir::Location loc,
334
363
mlir::Value resultBox, mlir::Value matrixABox,
335
364
mlir::Value matrixBBox) {
336
- auto func = fir::runtime::getRuntimeFunc<mkRTKey (Matmul)>(loc, builder);
365
+ mlir::func::FuncOp func;
366
+ auto boxATy = matrixABox.getType ();
367
+ auto arrATy = fir::dyn_cast_ptrOrBoxEleTy (boxATy);
368
+ auto arrAEleTy = mlir::cast<fir::SequenceType>(arrATy).getEleTy ();
369
+ auto [aCat, aKind] = fir::mlirTypeToCategoryKind (loc, arrAEleTy);
370
+ auto boxBTy = matrixBBox.getType ();
371
+ auto arrBTy = fir::dyn_cast_ptrOrBoxEleTy (boxBTy);
372
+ auto arrBEleTy = mlir::cast<fir::SequenceType>(arrBTy).getEleTy ();
373
+ auto [bCat, bKind] = fir::mlirTypeToCategoryKind (loc, arrBEleTy);
374
+
375
+ #define MATMUL_INSTANCE (ACAT, AKIND, BCAT, BKIND ) \
376
+ if (!func && aCat == TypeCategory::ACAT && aKind == AKIND && \
377
+ bCat == TypeCategory::BCAT && bKind == BKIND) { \
378
+ func = \
379
+ fir::runtime::getRuntimeFunc<ForcedMatmul##ACAT##AKIND##BCAT##BKIND>( \
380
+ loc, builder); \
381
+ }
382
+
383
+ #define MATMUL_DIRECT_INSTANCE (ACAT, AKIND, BCAT, BKIND )
384
+ #define MATMUL_FORCE_ALL_TYPES 1
385
+ #include " flang/Runtime/matmul-instances.inc"
386
+
387
+ if (!func) {
388
+ fir::intrinsicTypeTODO2 (builder, arrAEleTy, arrBEleTy, loc, " MATMUL" );
389
+ }
337
390
auto fTy = func.getFunctionType ();
338
391
auto sourceFile = fir::factory::locationToFilename (builder, loc);
339
392
auto sourceLine =
@@ -344,13 +397,48 @@ void fir::runtime::genMatmul(fir::FirOpBuilder &builder, mlir::Location loc,
344
397
builder.create <fir::CallOp>(loc, func, args);
345
398
}
346
399
347
- // / Generate call to MatmulTranspose intrinsic runtime routine.
400
+ // / Define ForcedMatmulTranspose<ACAT><AKIND><BCAT><BKIND> models.
401
+ #define MATMUL_INSTANCE (ACAT, AKIND, BCAT, BKIND ) \
402
+ struct ForcedMatmulTranspose ##ACAT##AKIND##BCAT##BKIND \
403
+ : public ForcedMatmulTypeModel { \
404
+ static constexpr const char *name = \
405
+ ExpandAndQuoteKey (RTNAME(MatmulTranspose##ACAT##AKIND##BCAT##BKIND)); \
406
+ };
407
+
408
+ #define MATMUL_DIRECT_INSTANCE (ACAT, AKIND, BCAT, BKIND )
409
+ #define MATMUL_FORCE_ALL_TYPES 1
410
+
411
+ #include " flang/Runtime/matmul-instances.inc"
412
+
348
413
void fir::runtime::genMatmulTranspose (fir::FirOpBuilder &builder,
349
414
mlir::Location loc, mlir::Value resultBox,
350
415
mlir::Value matrixABox,
351
416
mlir::Value matrixBBox) {
352
- auto func =
353
- fir::runtime::getRuntimeFunc<mkRTKey (MatmulTranspose)>(loc, builder);
417
+ mlir::func::FuncOp func;
418
+ auto boxATy = matrixABox.getType ();
419
+ auto arrATy = fir::dyn_cast_ptrOrBoxEleTy (boxATy);
420
+ auto arrAEleTy = mlir::cast<fir::SequenceType>(arrATy).getEleTy ();
421
+ auto [aCat, aKind] = fir::mlirTypeToCategoryKind (loc, arrAEleTy);
422
+ auto boxBTy = matrixBBox.getType ();
423
+ auto arrBTy = fir::dyn_cast_ptrOrBoxEleTy (boxBTy);
424
+ auto arrBEleTy = mlir::cast<fir::SequenceType>(arrBTy).getEleTy ();
425
+ auto [bCat, bKind] = fir::mlirTypeToCategoryKind (loc, arrBEleTy);
426
+
427
+ #define MATMUL_INSTANCE (ACAT, AKIND, BCAT, BKIND ) \
428
+ if (!func && aCat == TypeCategory::ACAT && aKind == AKIND && \
429
+ bCat == TypeCategory::BCAT && bKind == BKIND) { \
430
+ func = fir::runtime::getRuntimeFunc< \
431
+ ForcedMatmulTranspose##ACAT##AKIND##BCAT##BKIND>(loc, builder); \
432
+ }
433
+
434
+ #define MATMUL_DIRECT_INSTANCE (ACAT, AKIND, BCAT, BKIND )
435
+ #define MATMUL_FORCE_ALL_TYPES 1
436
+ #include " flang/Runtime/matmul-instances.inc"
437
+
438
+ if (!func) {
439
+ fir::intrinsicTypeTODO2 (builder, arrAEleTy, arrBEleTy, loc,
440
+ " MATMUL-TRANSPOSE" );
441
+ }
354
442
auto fTy = func.getFunctionType ();
355
443
auto sourceFile = fir::factory::locationToFilename (builder, loc);
356
444
auto sourceLine =
0 commit comments