@@ -714,6 +714,135 @@ class LLVMLoadStoreToOCLPattern : public OpConversionPattern<OpType> {
714714 }
715715};
716716
717+ // ===----------------------------------------------------------------------===//
718+ // GPU index id operations
719+ // ===----------------------------------------------------------------------===//
720+ /*
721+ // Launch Config ops
722+ // dimidx - x, y, z - is fixed to i32
723+ // return type is set by XeVM type converter
724+ // get_local_id
725+ xevm::WorkitemIdXOp;
726+ xevm::WorkitemIdYOp;
727+ xevm::WorkitemIdZOp;
728+ // get_local_size
729+ xevm::WorkgroupDimXOp;
730+ xevm::WorkgroupDimYOp;
731+ xevm::WorkgroupDimZOp;
732+ // get_group_id
733+ xevm::WorkgroupIdXOp;
734+ xevm::WorkgroupIdYOp;
735+ xevm::WorkgroupIdZOp;
736+ // get_num_groups
737+ xevm::GridDimXOp;
738+ xevm::GridDimYOp;
739+ xevm::GridDimZOp;
740+ // get_global_id : to be added if needed
741+ */
742+
743+ // Helpers to get the OpenCL function name and dimension argument for each op.
744+ static std::pair<StringRef, int64_t > getConfig (xevm::WorkitemIdXOp) {
745+ return {" get_local_id" , 0 };
746+ }
747+ static std::pair<StringRef, int64_t > getConfig (xevm::WorkitemIdYOp) {
748+ return {" get_local_id" , 1 };
749+ }
750+ static std::pair<StringRef, int64_t > getConfig (xevm::WorkitemIdZOp) {
751+ return {" get_local_id" , 2 };
752+ }
753+ static std::pair<StringRef, int64_t > getConfig (xevm::WorkgroupDimXOp) {
754+ return {" get_local_size" , 0 };
755+ }
756+ static std::pair<StringRef, int64_t > getConfig (xevm::WorkgroupDimYOp) {
757+ return {" get_local_size" , 1 };
758+ }
759+ static std::pair<StringRef, int64_t > getConfig (xevm::WorkgroupDimZOp) {
760+ return {" get_local_size" , 2 };
761+ }
762+ static std::pair<StringRef, int64_t > getConfig (xevm::WorkgroupIdXOp) {
763+ return {" get_group_id" , 0 };
764+ }
765+ static std::pair<StringRef, int64_t > getConfig (xevm::WorkgroupIdYOp) {
766+ return {" get_group_id" , 1 };
767+ }
768+ static std::pair<StringRef, int64_t > getConfig (xevm::WorkgroupIdZOp) {
769+ return {" get_group_id" , 2 };
770+ }
771+ static std::pair<StringRef, int64_t > getConfig (xevm::GridDimXOp) {
772+ return {" get_num_groups" , 0 };
773+ }
774+ static std::pair<StringRef, int64_t > getConfig (xevm::GridDimYOp) {
775+ return {" get_num_groups" , 1 };
776+ }
777+ static std::pair<StringRef, int64_t > getConfig (xevm::GridDimZOp) {
778+ return {" get_num_groups" , 2 };
779+ }
780+ // / Replace `xevm.*` with an `llvm.call` to the corresponding OpenCL func with
781+ // / a constant argument for the dimension - x, y or z.
782+ template <typename OpType>
783+ class LaunchConfigOpToOCLPattern : public OpConversionPattern <OpType> {
784+ using OpConversionPattern<OpType>::OpConversionPattern;
785+ LogicalResult
786+ matchAndRewrite (OpType op, typename OpType::Adaptor adaptor,
787+ ConversionPatternRewriter &rewriter) const override {
788+ Location loc = op->getLoc ();
789+ auto [baseName, dim] = getConfig (op);
790+ Type dimTy = rewriter.getI32Type ();
791+ Value dimVal = LLVM::ConstantOp::create (rewriter, loc, dimTy,
792+ static_cast <int64_t >(dim));
793+ std::string func = mangle (baseName, {dimTy}, {true });
794+ Type resTy = op.getType ();
795+ auto call =
796+ createDeviceFunctionCall (rewriter, func, resTy, {dimTy}, {dimVal}, {},
797+ noUnwindWillReturnAttrs, op.getOperation ());
798+ constexpr auto noModRef = LLVM::ModRefInfo::NoModRef;
799+ auto memAttr = rewriter.getAttr <LLVM::MemoryEffectsAttr>(
800+ /* other=*/ noModRef,
801+ /* argMem=*/ noModRef, /* inaccessibleMem=*/ noModRef);
802+ call.setMemoryEffectsAttr (memAttr);
803+ rewriter.replaceOp (op, call);
804+ return success ();
805+ }
806+ };
807+
808+ /*
809+ // Subgroup ops
810+ // get_sub_group_local_id
811+ xevm::LaneIdOp;
812+ // get_sub_group_id
813+ xevm::SubgroupIdOp;
814+ // get_sub_group_size
815+ xevm::SubgroupSizeOp;
816+ // get_num_sub_groups : to be added if needed
817+ */
818+
819+ // Helpers to get the OpenCL function name for each op.
820+ static StringRef getConfig (xevm::LaneIdOp) { return " get_sub_group_local_id" ; }
821+ static StringRef getConfig (xevm::SubgroupIdOp) { return " get_sub_group_id" ; }
822+ static StringRef getConfig (xevm::SubgroupSizeOp) {
823+ return " get_sub_group_size" ;
824+ }
825+ template <typename OpType>
826+ class SubgroupOpWorkitemOpToOCLPattern : public OpConversionPattern <OpType> {
827+ using OpConversionPattern<OpType>::OpConversionPattern;
828+ LogicalResult
829+ matchAndRewrite (OpType op, typename OpType::Adaptor adaptor,
830+ ConversionPatternRewriter &rewriter) const override {
831+ std::string func = mangle (getConfig (op).str (), {});
832+ Type resTy = op.getType ();
833+ auto call =
834+ createDeviceFunctionCall (rewriter, func, resTy, {}, {}, {},
835+ noUnwindWillReturnAttrs, op.getOperation ());
836+ constexpr auto noModRef = LLVM::ModRefInfo::NoModRef;
837+ auto memAttr = rewriter.getAttr <LLVM::MemoryEffectsAttr>(
838+ /* other=*/ noModRef,
839+ /* argMem=*/ noModRef, /* inaccessibleMem=*/ noModRef);
840+ call.setMemoryEffectsAttr (memAttr);
841+ rewriter.replaceOp (op, call);
842+ return success ();
843+ }
844+ };
845+
717846// ===----------------------------------------------------------------------===//
718847// Pass Definition
719848// ===----------------------------------------------------------------------===//
@@ -775,7 +904,22 @@ void ::mlir::populateXeVMToLLVMConversionPatterns(ConversionTarget &target,
775904 LLVMLoadStoreToOCLPattern<LLVM::LoadOp>,
776905 LLVMLoadStoreToOCLPattern<LLVM::StoreOp>,
777906 BlockLoadStore1DToOCLPattern<BlockLoadOp>,
778- BlockLoadStore1DToOCLPattern<BlockStoreOp>>(
907+ BlockLoadStore1DToOCLPattern<BlockStoreOp>,
908+ LaunchConfigOpToOCLPattern<WorkitemIdXOp>,
909+ LaunchConfigOpToOCLPattern<WorkitemIdYOp>,
910+ LaunchConfigOpToOCLPattern<WorkitemIdZOp>,
911+ LaunchConfigOpToOCLPattern<WorkgroupDimXOp>,
912+ LaunchConfigOpToOCLPattern<WorkgroupDimYOp>,
913+ LaunchConfigOpToOCLPattern<WorkgroupDimZOp>,
914+ LaunchConfigOpToOCLPattern<WorkgroupIdXOp>,
915+ LaunchConfigOpToOCLPattern<WorkgroupIdYOp>,
916+ LaunchConfigOpToOCLPattern<WorkgroupIdZOp>,
917+ LaunchConfigOpToOCLPattern<GridDimXOp>,
918+ LaunchConfigOpToOCLPattern<GridDimYOp>,
919+ LaunchConfigOpToOCLPattern<GridDimZOp>,
920+ SubgroupOpWorkitemOpToOCLPattern<LaneIdOp>,
921+ SubgroupOpWorkitemOpToOCLPattern<SubgroupIdOp>,
922+ SubgroupOpWorkitemOpToOCLPattern<SubgroupSizeOp>>(
779923 patterns.getContext ());
780924}
781925
0 commit comments