diff --git a/include/triton/Tools/Sys/GetEnv.hpp b/include/triton/Tools/Sys/GetEnv.hpp index aaab0a4474..a2c7887121 100644 --- a/include/triton/Tools/Sys/GetEnv.hpp +++ b/include/triton/Tools/Sys/GetEnv.hpp @@ -50,6 +50,10 @@ inline const std::set CACHE_INVALIDATING_ENV_VARS = { "TRITON_INTEL_ENABLE_INSTR_SCHED", "TRITON_INTEL_FAST_MATH", "TRITON_INTEL_REDUCE_TRANSPOSE", + "TRITON_INTEL_ENABLE_SIMD_REDUCE", + "TRITON_INTEL_ENHANCED_ACCELERATION_MATMUL", + "TRITON_INTEL_ENABLE_DPAS_WARP_SIZE_32", + "TRITONGEN_FORCE_GENISA", // clang-format on }; diff --git a/third_party/intel/include/Dialect/TritonGEN/IR/TritonGENAttrDefs.td b/third_party/intel/include/Dialect/TritonGEN/IR/TritonGENAttrDefs.td index 38662d4245..8ed4dfea09 100644 --- a/third_party/intel/include/Dialect/TritonGEN/IR/TritonGENAttrDefs.td +++ b/third_party/intel/include/Dialect/TritonGEN/IR/TritonGENAttrDefs.td @@ -55,9 +55,9 @@ def TritonGEN_PrecisionTypeAttr : I32EnumAttr<"PrecisionType", I32EnumAttrCase<"S4", 5, "i4">, I32EnumAttrCase<"S2", 6, "i2">, I32EnumAttrCase<"BF8", 7, "bf8">, - I32EnumAttrCase<"TF32", 8, "tf32">, - I32EnumAttrCase<"BF16", 9, "bf16">, - I32EnumAttrCase<"FP16", 10, "f16"> + I32EnumAttrCase<"TF32", 10, "tf32">, + I32EnumAttrCase<"BF16", 11, "bf16">, + I32EnumAttrCase<"FP16", 12, "f16"> ]> { let cppNamespace = "::mlir::triton::TritonGEN"; } diff --git a/third_party/intel/lib/Dialect/TritonGEN/IR/TritonGENOps.cpp b/third_party/intel/lib/Dialect/TritonGEN/IR/TritonGENOps.cpp index b2bf17539e..798326132c 100644 --- a/third_party/intel/lib/Dialect/TritonGEN/IR/TritonGENOps.cpp +++ b/third_party/intel/lib/Dialect/TritonGEN/IR/TritonGENOps.cpp @@ -85,12 +85,15 @@ LogicalResult TritonGEN::MatrixDPASOp::verify() { return this->emitOpError( "1st operand (C) and result (D) should have the same type"); - if (CTy.getNumElements() != getRc() || DTy.getNumElements() != getRc()) + auto useGenISA = tools::getBoolEnv("TRITONGEN_FORCE_GENISA"); + + if (!useGenISA && + (CTy.getNumElements() != 
getRc() || DTy.getNumElements() != getRc())) return this->emitOpError("the dimension for 1st operand (C) and " "result (D) should match repeat count"); constexpr unsigned SD = 8; - if (BTy.getNumElements() != SD) + if (!useGenISA && BTy.getNumElements() != SD) return this->emitOpError("the dimension for the 3rd operand (B) should " "match the systolic depth of 8"); @@ -141,7 +144,7 @@ LogicalResult TritonGEN::MatrixDPASOp::verify() { case TritonGEN::PrecisionType::FP16: case TritonGEN::PrecisionType::U8: case TritonGEN::PrecisionType::S8: - if (ATy.getNumElements() != getRc()) + if (!useGenISA && ATy.getNumElements() != getRc()) return this->emitOpError("2nd operand (A) should have the same number of " "elements as repeat count"); if (!AElemTy.isInteger(16)) @@ -303,6 +306,9 @@ LogicalResult TritonGEN::Matrix2DBlockLoadOp::verify() { if (verify2DBlockLoadHWRestriction(*this).failed()) return failure(); + if (tools::getBoolEnv("TRITONGEN_FORCE_GENISA")) + return success(); + if (verifyMatrixInput(*this).failed()) return failure(); @@ -367,6 +373,9 @@ LogicalResult TritonGEN::Matrix2DBlockStoreOp::verify() { if (verify2DBlockStoreHWRestriction(*this).failed()) return failure(); + if (tools::getBoolEnv("TRITONGEN_FORCE_GENISA")) + return success(); + if (verifyMatrixInput(*this).failed()) return failure(); diff --git a/third_party/intel/lib/TritonGENToLLVM/CMakeLists.txt b/third_party/intel/lib/TritonGENToLLVM/CMakeLists.txt index 555582643b..07d0feb610 100644 --- a/third_party/intel/lib/TritonGENToLLVM/CMakeLists.txt +++ b/third_party/intel/lib/TritonGENToLLVM/CMakeLists.txt @@ -1,12 +1,21 @@ +add_library(GenISAIntrinsics STATIC IMPORTED GLOBAL) +set_target_properties(GenISAIntrinsics + PROPERTIES IMPORTED_LOCATION ${CMAKE_CURRENT_SOURCE_DIR}/libGenISAIntrinsics.a +) + add_triton_library(TritonGENToLLVM Attributes.cpp TritonGENToLLVMPass.cpp + GenIntrinsicHelper.cpp DEPENDS TritonGENToLLVMConversionPassIncGen + GenISAIntrinsics LINK_LIBS PUBLIC + GenISAIntrinsics 
MLIRLLVMDialect MLIRSPIRVDialect TritonIntelUtils + MLIRTargetLLVMIRImport ) diff --git a/third_party/intel/lib/TritonGENToLLVM/GenIntrinsicEnum.h b/third_party/intel/lib/TritonGENToLLVM/GenIntrinsicEnum.h new file mode 100644 index 0000000000..7f95456a57 --- /dev/null +++ b/third_party/intel/lib/TritonGENToLLVM/GenIntrinsicEnum.h @@ -0,0 +1,434 @@ +/*========================== begin_copyright_notice ============================ + +Copyright (C) 2023 Intel Corporation + +SPDX-License-Identifier: MIT + +============================= end_copyright_notice ===========================*/ +#pragma once + +#include "llvm/IR/Intrinsics.h" + +#include + +namespace llvm { + +namespace GenISAIntrinsic { + +enum ID : uint32_t { + no_intrinsic = llvm::Intrinsic::num_intrinsics, + GenISA_2fto2bf, + GenISA_assume_uniform, + GenISA_bftof, + GenISA_CatchAllDebugLine, + GenISA_DCL_DSCntrlPtInputVec, + GenISA_DCL_DSInputTessFactor, + GenISA_DCL_DSPatchConstInputVec, + GenISA_DCL_GSinputVec, + GenISA_DCL_GSsystemValue, + GenISA_DCL_HSControlPointID, + GenISA_DCL_HSOutputCntrlPtInputVec, + GenISA_DCL_HSPatchConstInputVec, + GenISA_DCL_HSinputVec, + GenISA_DCL_ShaderInputVec, + GenISA_DCL_SystemValue, + GenISA_DCL_input, + GenISA_DCL_inputVec, + GenISA_dpas, + GenISA_EmitHitAttributes, + GenISA_EndPrimitive, + GenISA_ftobf, + GenISA_GetBufferPtr, + GenISA_GetImplicitBufferPtr, + GenISA_GetLocalIdBufferPtr, + GenISA_GetPixelMask, + GenISA_GradientX, + GenISA_GradientXfine, + GenISA_GradientY, + GenISA_GradientYfine, + GenISA_GsCutControlHeader, + GenISA_GsStreamHeader, + GenISA_HSURBPatchHeaderRead, + GenISA_IEEE_Divide, + GenISA_IEEE_Sqrt, + GenISA_InitDiscardMask, + GenISA_InnerScalarTessFactors, + GenISA_Interpolant, + GenISA_Interpolate, + GenISA_Interpolate2, + GenISA_IsHelperInvocation, + GenISA_MediaBlockRead, + GenISA_MediaBlockRectangleRead, + GenISA_MediaBlockWrite, + GenISA_OUTPUT, + GenISA_OUTPUTGS, + GenISA_OuterScalarTessFactors, + GenISA_OutputTessControlPoint, + 
GenISA_PHASE_INPUT, + GenISA_PHASE_INPUTVEC, + GenISA_PHASE_OUTPUT, + GenISA_PHASE_OUTPUTVEC, + GenISA_PatchConstantOutput, + GenISA_PixelPositionX, + GenISA_PixelPositionY, + GenISA_PullCentroidBarys, + GenISA_PullSampleIndexBarys, + GenISA_PullSnappedBarys, + GenISA_QuadPrefix, + GenISA_ROUNDNE, + GenISA_RTDualBlendSource, + GenISA_RTWrite, + GenISA_ReadFromReservedArgSpace, + GenISA_RenderTargetRead, + GenISA_RenderTargetReadSampleFreq, + GenISA_RuntimeValue, + GenISA_SampleOffsetX, + GenISA_SampleOffsetY, + GenISA_SaveInReservedArgSpace, + GenISA_SetStackCallsBaseAddress, + GenISA_SetImplicitBufferPtr, + GenISA_SetDebugReg, + GenISA_SetLocalIdBufferPtr, + GenISA_SetStream, + GenISA_StackAlloca, + GenISA_VLAStackAlloca, + GenISA_UnmaskedRegionBegin, + GenISA_UnmaskedRegionEnd, + GenISA_URBRead, + GenISA_URBReadOutput, + GenISA_URBWrite, + GenISA_UpdateDiscardMask, + GenISA_WaveAll, + GenISA_WaveBallot, + GenISA_WaveClustered, + GenISA_WaveInverseBallot, + GenISA_WavePrefix, + GenISA_WaveShuffleIndex, + GenISA_WaveBroadcast, + GenISA_WorkGroupAny, + GenISA_add_pair, + GenISA_add_rtz, + GenISA_atomiccounterinc, + GenISA_atomiccounterpredec, + GenISA_bfi, + GenISA_bfrev, + GenISA_broadcastMessagePhase, + GenISA_broadcastMessagePhaseV, + GenISA_cmpSADs, + GenISA_cmpxchgatomicstructured, + GenISA_createMessagePhases, + GenISA_createMessagePhasesNoInit, + GenISA_createMessagePhasesNoInitV, + GenISA_createMessagePhasesV, + GenISA_cycleCounter, + GenISA_discard, + GenISA_dp4a_ss, + GenISA_dp4a_su, + GenISA_dp4a_us, + GenISA_dp4a_uu, + GenISA_dummyInst, + GenISA_dummyInstID, + GenISA_launder, + GenISA_dwordatomicstructured, + GenISA_eu_id, + GenISA_eu_thread_id, + GenISA_eu_thread_pause, + GenISA_evaluateSampler, + GenISA_extractMVAndSAD, + GenISA_f32tof16_rtz, + GenISA_fcmpxchgatomicraw, + GenISA_fcmpxchgatomicrawA64, + GenISA_fcmpxchgatomicstructured, + GenISA_firstbitHi, + GenISA_firstbitLo, + GenISA_firstbitShi, + GenISA_floatatomicraw, + GenISA_floatatomicrawA64, + 
GenISA_floatatomicstructured, + GenISA_flushsampler, + GenISA_fma_rtz, + GenISA_fma_rtp, + GenISA_fma_rtn, + GenISA_fsat, + GenISA_usat, + GenISA_isat, + GenISA_ftof_rte, + GenISA_ftof_rtn, + GenISA_ftof_rtp, + GenISA_ftof_rtz, + GenISA_ftoi_rte, + GenISA_ftoi_rtn, + GenISA_ftoi_rtp, + GenISA_ftoui_rte, + GenISA_ftoui_rtn, + GenISA_ftoui_rtp, + GenISA_sampleMlodptr, + GenISA_sampleCMlodptr, + GenISA_sampleBCMlodptr, + GenISA_sampleDCMlodptr, + GenISA_samplePOptr, + GenISA_samplePOBptr, + GenISA_samplePOLptr, + GenISA_samplePOCptr, + GenISA_samplePODptr, + GenISA_gather4Iptr, + GenISA_gather4Bptr, + GenISA_gather4Lptr, + GenISA_samplePOLCptr, + GenISA_gather4ICptr, + GenISA_gather4LCptr, + GenISA_gather4POPackedptr, + GenISA_gather4POPackedLptr, + GenISA_gather4POPackedBptr, + GenISA_gather4POPackedIptr, + GenISA_gather4POPackedICptr, + GenISA_gather4POPackedLCptr, + GenISA_gather4POPackedCptr, + GenISA_gather4IPOptr, + GenISA_gather4BPOptr, + GenISA_gather4LPOptr, + GenISA_gather4ICPOptr, + GenISA_gather4LCPOptr, + GenISA_gather4Cptr, + GenISA_gather4POCptr, + GenISA_gather4POptr, + GenISA_gather4ptr, + GenISA_getMessagePhase, + GenISA_getMessagePhaseV, + GenISA_getMessagePhaseX, + GenISA_getMessagePhaseXV, + GenISA_getR0, + GenISA_getPayloadHeader, + GenISA_getWorkDim, + GenISA_getNumWorkGroups, + GenISA_getGlobalSize, + GenISA_getLocalSize, + GenISA_getEnqueuedLocalSize, + GenISA_getLocalID_X, + GenISA_getLocalID_Y, + GenISA_getLocalID_Z, + GenISA_getPrivateBase, + GenISA_getPrintfBuffer, + GenISA_getStageInGridOrigin, + GenISA_getStageInGridSize, + GenISA_getSyncBuffer, + GenISA_getRtGlobalBufferPtr, + GenISA_getStackPointer, + GenISA_getStackSizePerThread, + GenISA_getAssertBufferPtr, + GenISA_getSR0, + GenISA_getSR0_0, + GenISA_globalSync, + GenISA_hw_thread_id, + GenISA_hw_thread_id_alloca, + GenISA_ibfe, + GenISA_icmpxchgatomicraw, + GenISA_icmpxchgatomicrawA64, + GenISA_icmpxchgatomictyped, + GenISA_fcmpxchgatomictyped, + GenISA_imulH, + 
GenISA_intatomicraw, + GenISA_intatomicrawA64, + GenISA_intatomictyped, + GenISA_floatatomictyped, + GenISA_is_uniform, + GenISA_itof_rtn, + GenISA_itof_rtp, + GenISA_itof_rtz, + GenISA_ldmcsptr, + GenISA_ldmsptr, + GenISA_ldmsptr16bit, + GenISA_ldptr, + GenISA_ldlptr, + GenISA_ldraw_indexed, + GenISA_ldrawvector_indexed, + GenISA_ldstructured, + GenISA_lodptr, + GenISA_memoryfence, + GenISA_mov_identity, + GenISA_movcr, + GenISA_movflag, + GenISA_software_exception, + GenISA_mul_pair, + GenISA_mul_rtz, + GenISA_pair_to_ptr, + GenISA_patchInstanceId, + GenISA_ptr_to_pair, + GenISA_readsurfacetypeandformat, + GenISA_resinfoptr, + GenISA_rsq, + GenISA_sampleBCptr, + GenISA_sampleBptr, + GenISA_sampleCptr, + GenISA_sampleDCptr, + GenISA_sampleDptr, + GenISA_sampleKillPix, + GenISA_sampleLCptr, + GenISA_sampleLptr, + GenISA_sampleinfoptr, + GenISA_sampleptr, + GenISA_setMessagePhase, + GenISA_setMessagePhaseV, + GenISA_setMessagePhaseX, + GenISA_setMessagePhaseXV, + GenISA_setMessagePhaseX_legacy, + GenISA_setMessagePhase_legacy, + GenISA_simdBlockRead, + GenISA_simdBlockReadBindless, + GenISA_simdBlockWrite, + GenISA_simdBlockWriteBindless, + GenISA_simdGetMessagePhase, + GenISA_simdGetMessagePhaseV, + GenISA_simdLaneId, + GenISA_simdLaneIdReplicate, + GenISA_simdMediaBlockRead, + GenISA_simdMediaBlockWrite, + GenISA_simdMediaRegionCopy, + GenISA_simdSetMessagePhase, + GenISA_simdSetMessagePhaseV, + GenISA_simdShuffleDown, + GenISA_simdShuffleXor, + GenISA_simdSize, + GenISA_slice_id, + GenISA_source_value, + GenISA_storeraw_indexed, + GenISA_storerawvector_indexed, + GenISA_storestructured1, + GenISA_storestructured2, + GenISA_storestructured3, + GenISA_storestructured4, + GenISA_sub_group_dpas, + GenISA_sub_pair, + GenISA_subslice_id, + GenISA_logical_subslice_id, + GenISA_dual_subslice_id, + GenISA_threadgroupbarrier, + GenISA_threadgroupbarrier_signal, + GenISA_threadgroupbarrier_wait, + GenISA_typedmemoryfence, + GenISA_typedread, + GenISA_typedwrite, + 
GenISA_uaddc, + GenISA_uavSerializeAll, + GenISA_uavSerializeOnResID, + GenISA_ubfe, + GenISA_uitof_rtn, + GenISA_uitof_rtp, + GenISA_uitof_rtz, + GenISA_umulH, + GenISA_usubb, + GenISA_vaBoolCentroid, + GenISA_vaBoolSum, + GenISA_vaCentroid, + GenISA_vaConvolve, + GenISA_vaConvolveGRF_16x1, + GenISA_vaConvolveGRF_16x4, + GenISA_vaDilate, + GenISA_vaErode, + GenISA_vaMinMax, + GenISA_vaMinMaxFilter, + GenISA_vectorUniform, + GenISA_vmeSendFBR, + GenISA_vmeSendFBR2, + GenISA_vmeSendIME, + GenISA_vmeSendIME2, + GenISA_vmeSendSIC, + GenISA_vmeSendSIC2, + GenISA_wavebarrier, + GenISA_frc, + GenISA_staticConstantPatchValue, + GenISA_HDCCCSFastClear, + GenISA_LSC2DBlockRead, + GenISA_LSC2DBlockWrite, + GenISA_LSC2DBlockPrefetch, + GenISA_LSCAtomicFP32, + GenISA_LSCAtomicFP64, + GenISA_LSCAtomicInts, + GenISA_LSCFence, + GenISA_LSCLoad, + GenISA_LSCLoadCmask, + GenISA_LSCLoadBlock, + GenISA_LSCLoadStatus, + GenISA_LSCPrefetch, + GenISA_LSCStore, + GenISA_LSCStoreCmask, + GenISA_LSCStoreBlock, + GenISA_bf8tohf, + GenISA_tf32tof, + GenISA_HDCuncompressedwrite, + GenISA_systemmemoryfence, + GenISA_urbfence, + GenISA_threadgroupnamedbarriers_signal, + GenISA_threadgroupnamedbarriers_wait, + GenISA_hftobf8, + GenISA_ftotf32, + GenISA_srnd_hftobf8, + GenISA_srnd_ftohf, + GenISA_OutputMeshPrimitiveData, + GenISA_OutputMeshPrimitiveDataInput, + GenISA_OutputMeshSivDataInput, + GenISA_OutputMeshVertexData, + GenISA_OutputMeshVertexDataInput, + GenISA_OutputTaskData, + GenISA_OutputTaskDataInput, + GenISA_AcceptHitAndEndSearchHL, + GenISA_AllocaNumber, + GenISA_AllocateRayQuery, + GenISA_AsyncStackID, + GenISA_AsyncStackPtr, + GenISA_SyncStackPtr, + GenISA_BindlessThreadDispatch, + GenISA_CallShaderHL, + GenISA_DispatchDimensions, + GenISA_DispatchRayIndex, + GenISA_FillValue, + GenISA_GetShaderRecordPtr, + GenISA_GlobalBufferPointer, + GenISA_GlobalRootSignatureValue, + GenISA_HitKind, + GenISA_IgnoreHitHL, + GenISA_InlinedData, + GenISA_LocalBufferPointer, + 
GenISA_LocalRootSignatureValue, + GenISA_PayloadPtr, + GenISA_PreemptionEnable, + GenISA_PreemptionDisable, + GenISA_RayQueryCheck, + GenISA_RayQueryRelease, + GenISA_ContinuationSignpost, + GenISA_RTStatefulBTIAndOffset, + GenISA_RayInfo, + GenISA_RayTCurrent, + GenISA_ReportHitHL, + GenISA_TileXOffset, + GenISA_TileYOffset, + GenISA_SpillValue, + GenISA_StackIDRelease, + GenISA_StackSize, + GenISA_SWHotZonePtr, + GenISA_SWStackPtr, + GenISA_TraceRayAsync, + GenISA_TraceRaySync, + GenISA_TraceRaySyncProceed, + GenISA_ShadowMemoryToSyncStack, + GenISA_SyncStackToShadowMemory, + GenISA_ReadTraceRaySync, + GenISA_TraceRayAsyncHL, + GenISA_TraceRayInlineAbort, + GenISA_TraceRayInlineCandidateType, + GenISA_TraceRayInlineCommitNonOpaqueTriangleHit, + GenISA_TraceRayInlineCommitProceduralPrimitiveHit, + GenISA_TraceRayInlineCommittedStatus, + GenISA_TraceRayInlineHL, + GenISA_TraceRaySyncProceedHL, + GenISA_TraceRayInlineRayInfo, + GenISA_rt_swstack_offset, + GenISA_FPBinaryOperator, + GenISA_bitcastfromstruct, + GenISA_bitcasttostruct, + num_genisa_intrinsics +}; + +} // namespace GenISAIntrinsic + +} // namespace llvm diff --git a/third_party/intel/lib/TritonGENToLLVM/GenIntrinsicHelper.cpp b/third_party/intel/lib/TritonGENToLLVM/GenIntrinsicHelper.cpp new file mode 100644 index 0000000000..dff436dc93 --- /dev/null +++ b/third_party/intel/lib/TritonGENToLLVM/GenIntrinsicHelper.cpp @@ -0,0 +1,147 @@ +#include "GenIntrinsicHelper.h" +#include "mlir/Dialect/LLVMIR/LLVMTypes.h" +#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" +#include "mlir/Target/LLVMIR/ModuleImport.h" +#include "mlir/Target/LLVMIR/TypeFromLLVM.h" +#include "mlir/Target/LLVMIR/TypeToLLVM.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" + +namespace mlir { +namespace triton { +namespace gpu { +namespace intel { + +// Function attribute conversion adapted from the original code here: +// 
https://github.com/llvm/llvm-project/blob/e575b7cb7a64297583d6382c16ce264d9fe45d08/mlir/lib/Target/LLVMIR/ModuleImport.cpp#L1547 +// List of LLVM IR attributes that map to an explicit attribute on the MLIR +// LLVMFuncOp. +static constexpr std::array ExplicitAttributes{ + StringLiteral("aarch64_pstate_sm_enabled"), + StringLiteral("aarch64_pstate_sm_body"), + StringLiteral("aarch64_pstate_sm_compatible"), + StringLiteral("aarch64_new_za"), + StringLiteral("aarch64_preserves_za"), + StringLiteral("aarch64_in_za"), + StringLiteral("aarch64_out_za"), + StringLiteral("aarch64_inout_za"), + StringLiteral("vscale_range"), + StringLiteral("frame-pointer"), + StringLiteral("target-features"), + StringLiteral("unsafe-fp-math"), + StringLiteral("no-infs-fp-math"), + StringLiteral("no-nans-fp-math"), + StringLiteral("approx-func-fp-math"), + StringLiteral("no-signed-zeros-fp-math"), +}; + +static void processPassthroughAttrs(llvm::Function *func, + mlir::LLVM::LLVMFuncOp funcOp) { + MLIRContext *context = funcOp.getContext(); + SmallVector passthroughs; + llvm::AttributeSet funcAttrs = func->getAttributes().getAttributes( + llvm::AttributeList::AttrIndex::FunctionIndex); + for (llvm::Attribute attr : funcAttrs) { + // Skip the memory attribute since the LLVMFuncOp has an explicit memory + // attribute. + if (attr.hasAttribute(llvm::Attribute::Memory)) + continue; + + // Skip invalid type attributes. + if (attr.isTypeAttribute()) { + emitWarning(funcOp.getLoc(), + "type attributes on a function are invalid, skipping it"); + continue; + } + + StringRef attrName; + if (attr.isStringAttribute()) + attrName = attr.getKindAsString(); + else + attrName = llvm::Attribute::getNameFromAttrKind(attr.getKindAsEnum()); + auto keyAttr = StringAttr::get(context, attrName); + + // Skip attributes that map to an explicit attribute on the LLVMFuncOp. 
+ if (llvm::is_contained(ExplicitAttributes, attrName)) + continue; + + if (attr.isStringAttribute()) { + StringRef val = attr.getValueAsString(); + if (val.empty()) { + passthroughs.push_back(keyAttr); + continue; + } + passthroughs.push_back( + ArrayAttr::get(context, {keyAttr, StringAttr::get(context, val)})); + continue; + } + if (attr.isIntAttribute()) { + auto val = std::to_string(attr.getValueAsInt()); + passthroughs.push_back( + ArrayAttr::get(context, {keyAttr, StringAttr::get(context, val)})); + continue; + } + if (attr.isEnumAttribute()) { + passthroughs.push_back(keyAttr); + continue; + } + + llvm_unreachable("unexpected attribute kind"); + } + + if (!passthroughs.empty()) + funcOp.setPassthroughAttr(ArrayAttr::get(context, passthroughs)); +} + +mlir::LLVM::LLVMFuncOp +appendOrGetGenISADeclaration(OpBuilder &builder, llvm::GenISAIntrinsic::ID id, + ArrayRef mlirTys) { + auto mlirContext = builder.getContext(); + + SmallVector llvmTys; + llvm::LLVMContext llvmContext; + std::unique_ptr llvmModule = + std::make_unique("temp", llvmContext); + mlir::LLVM::TypeToLLVMIRTranslator llvmToMLIR(llvmContext); + for (mlir::Type *ty : mlirTys) { + llvmTys.push_back(llvmToMLIR.translateType(*ty)); + } + auto llvmFunc = + llvm::GenISAIntrinsic::getDeclaration(llvmModule.get(), id, llvmTys); + + auto genISAName = llvmFunc->getName(); + + auto funcName = StringAttr::get(mlirContext, genISAName); + Operation *funcOp = SymbolTable::lookupNearestSymbolFrom( + builder.getBlock() + ->getParent() + ->getParentOfType(), + funcName); + + if (funcOp) + return cast(*funcOp); + + auto llvmFuncType = llvmFunc->getFunctionType(); + LLVM::TypeFromLLVMIRTranslator mlirFromLLVM(*mlirContext); + auto mlirFuncTy = mlirFromLLVM.translateType(llvmFuncType); + mlir::LLVM::LLVMFunctionType funcTy = + cast(mlirFuncTy); + + auto parent = builder.getBlock() + ->getParent() + ->getParentOfType(); + mlir::OpBuilder b(parent); + auto ret = + b.create(mlir::UnknownLoc::get(mlirContext), genISAName, 
+ funcTy, LLVM::Linkage::External, + /*dsoLocal*/ false, LLVM::CConv::C, + /*comdat=*/SymbolRefAttr{}); + + processPassthroughAttrs(llvmFunc, ret); + + return ret; +} + +} // namespace intel +} // namespace gpu +} // namespace triton +} // namespace mlir diff --git a/third_party/intel/lib/TritonGENToLLVM/GenIntrinsicHelper.h b/third_party/intel/lib/TritonGENToLLVM/GenIntrinsicHelper.h new file mode 100644 index 0000000000..1cdb3e0b7d --- /dev/null +++ b/third_party/intel/lib/TritonGENToLLVM/GenIntrinsicHelper.h @@ -0,0 +1,76 @@ +//===- GenIntrinsicHelper.h - Gen intrinsic helper ------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef TRITON_VCINTRINSICHELPER_H +#define TRITON_VCINTRINSICHELPER_H + +#include "GenIntrinsics.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" + +namespace mlir { +namespace triton { +namespace gpu { +namespace intel { + +mlir::LLVM::LLVMFuncOp +appendOrGetGenISADeclaration(OpBuilder &builder, llvm::GenISAIntrinsic::ID id, + ArrayRef mlirTys = {}); + +class Intrinsic { +protected: + LLVM::LLVMFuncOp intrinsicDecl; + +public: + Value operator()(OpBuilder &rewriter, Location loc, ValueRange args) { + auto funName = intrinsicDecl.getName(); + auto retType = intrinsicDecl.getResultTypes(); + auto funCall = rewriter.create(loc, retType, funName, args); + funCall.setCConv(LLVM::cconv::CConv::SPIR_FUNC); + return funCall.getResult(); + } +}; + +template class GenISA : public Intrinsic { +public: + template + explicit GenISA(OpBuilder &builder, OverrideTypes... retTy) { + // get GenISA intrinsic declaration. + intrinsicDecl = appendOrGetGenISADeclaration(builder, INST_ID, {&retTy...}); + } + + template + LLVM::CallOp operator()(OpBuilder &rewriter, Location loc, Args... 
args) { + auto funName = intrinsicDecl.getName(); + auto retType = intrinsicDecl.getResultTypes(); + auto funCall = rewriter.create(loc, retType, funName, + ValueRange{args...}); + funCall.setCConv(LLVM::cconv::CConv::SPIR_FUNC); + return funCall; + } +}; + +class GenISA_Prefetching_2D + : public GenISA { + +public: + using GenISA::GenISA; +}; + +class GenISA_Dpas + : public GenISA { + +public: + using GenISA::GenISA; +}; + +} // namespace intel +} // namespace gpu +} // namespace triton +} // namespace mlir + +#endif // TRITON_VCINTRINSICHELPER_H diff --git a/third_party/intel/lib/TritonGENToLLVM/GenIntrinsics.h b/third_party/intel/lib/TritonGENToLLVM/GenIntrinsics.h new file mode 100644 index 0000000000..ec69779e4f --- /dev/null +++ b/third_party/intel/lib/TritonGENToLLVM/GenIntrinsics.h @@ -0,0 +1,72 @@ +/*========================== begin_copyright_notice ============================ + +Copyright (C) 2017-2021 Intel Corporation + +SPDX-License-Identifier: MIT + +============================= end_copyright_notice ===========================*/ +#pragma once + +#include "GenIntrinsicEnum.h" + +#include "llvm/IR/Function.h" + +#include +#include + +namespace llvm { + +namespace GenISAIntrinsic { + +/// Intrinsic::getName(ID) - Return the LLVM name for an intrinsic, such as +/// "llvm.ppc.altivec.lvx". +std::string getName(ID id, ArrayRef Tys = std::nullopt); + +/// Intrinsic::getType(ID) - Return the function type for an intrinsic. +/// +FunctionType *getType(LLVMContext &Context, ID id, + ArrayRef Tys = std::nullopt); + +llvm::AttributeList getGenIntrinsicAttributes(LLVMContext &Context, ID id); + +struct IntrinsicComments { + const char *funcDescription; + std::vector outputs; + std::vector inputs; +}; + +IntrinsicComments getIntrinsicComments(ID id); + +/// Intrinsic::getDeclaration(M, ID) - Create or insert an LLVM Function +/// declaration for an intrinsic, and return it. 
+/// +/// The OverloadedTys parameter is for intrinsics with overloaded types +/// (i.e., those using iAny, fAny, vAny, or iPTRAny). For a declaration of +/// an overloaded intrinsic, Tys must provide exactly one type for each +/// overloaded type in the intrinsic in order of dst then srcs. +/// +/// For instance, consider the following overloaded function. +/// uint2 foo(size_t offset, int bar, const __global uint2 *p); +/// uint4 foo(size_t offset, int bar, const __global uint4 *p); +/// Such a function has two overloaded type parameters: dst and src2. +/// Thus the type array should have two elements: +/// Type Ts[2]{int2, int2}: to resolve to the first instance. +/// Type Ts[2]{int4, int4}: to resolve to the second. +#if defined(ANDROID) || defined(__linux__) +__attribute__((visibility("default"))) Function * +getDeclaration(Module *M, ID id, ArrayRef OverloadedTys = std::nullopt); +#else +Function *getDeclaration(Module *M, ID id, + ArrayRef OverloadedTys = None); +#endif + +// Override of isIntrinsic method defined in Function.h +inline const char *getGenIntrinsicPrefix() { return "llvm.genx."; } +inline bool isIntrinsic(const Function *CF) { + return (CF->getName().starts_with(getGenIntrinsicPrefix())); +} +ID getIntrinsicID(const Function *F, bool useContextWrapper = true); + +} // namespace GenISAIntrinsic + +} // namespace llvm diff --git a/third_party/intel/lib/TritonGENToLLVM/TritonGENToLLVMPass.cpp b/third_party/intel/lib/TritonGENToLLVM/TritonGENToLLVMPass.cpp index 6740a47e78..97f39769f6 100644 --- a/third_party/intel/lib/TritonGENToLLVM/TritonGENToLLVMPass.cpp +++ b/third_party/intel/lib/TritonGENToLLVM/TritonGENToLLVMPass.cpp @@ -40,6 +40,10 @@ #include "intel/include/TritonGENToLLVM/TritonGENToLLVMPass.h" #include "intel/include/TritonGENToSPIRV/TritonGENToSPIRVPass.h" +#include + +#include "GenIntrinsicHelper.h" + namespace mlir::triton { #define GEN_PASS_DEF_CONVERTTRITONGENTOLLVM #include "intel/include/TritonGENToLLVM/Passes.h.inc" @@ -431,27 
+435,48 @@ struct TritonMatrixDPASLowering if (cOrigTy != cTy) c = rewriter.create(loc, cTy, c); - std::string fnName = "__spirv_SubgroupMatrixMultiplyAccumulateINTEL"; - SmallVector argTypes{int32Ty, aTy, bTy, cTy, int32Ty}; - fnName = intel::mangle(fnName, argTypes); - - TritonLLVMOpBuilder builder(loc, rewriter); - Value kDim = builder.i32_val(8 /*systolic depth*/ * - getNumOperandsPerDword(precisionA)); - SmallVector args{ - kDim, a, b, c, - builder.i32_val(getMatrixMultiplyAccumulateOperandsVal( - cOrigTy.getElementType(), precisionA))}; - auto memAttr = rewriter.getAttr( - /*other=*/LLVM::ModRefInfo::NoModRef, - /*argMem=*/LLVM::ModRefInfo::NoModRef, - /*inaccessibleMem=*/LLVM::ModRefInfo::NoModRef); - auto funcAttrs = intel::convergentNoUnwindWillReturnAttrs; - funcAttrs.memEffectsAttr = memAttr; + Value result; + if (tools::getBoolEnv("TRITONGEN_FORCE_GENISA")) { + MLIRContext *ctx = rewriter.getContext(); + auto builder = TritonLLVMOpBuilder(loc, rewriter); + mlir::triton::gpu::intel::GenISA_Dpas dpasOp(rewriter, cTy, cTy, aTy, + bTy); + + // Refer to the call signature in GenISA. + result = + dpasOp(rewriter, loc, c, a, b, + builder.i32_val( + static_cast(precisionA)), /*src0's precision*/ + builder.i32_val( + static_cast(op.getPb())), /*src1's precision*/ + builder.i32_val(8), /*systolic depth*/ + builder.i32_val(8), /*repeat count*/ + builder.int_val(1, 0) /*is double = false*/) + ->getResult(0); + } else { + std::string fnName = "__spirv_SubgroupMatrixMultiplyAccumulateINTEL"; + SmallVector argTypes{int32Ty, aTy, bTy, cTy, int32Ty}; + fnName = intel::mangle(fnName, argTypes); + + TritonLLVMOpBuilder builder(loc, rewriter); + Value kDim = builder.i32_val(8 /*systolic depth*/ * + getNumOperandsPerDword(precisionA)); + SmallVector args{ + kDim, a, b, c, + builder.i32_val(getMatrixMultiplyAccumulateOperandsVal( + cOrigTy.getElementType(), precisionA))}; + auto memAttr = rewriter.getAttr( + /*other=*/LLVM::ModRefInfo::NoModRef, + 
/*argMem=*/LLVM::ModRefInfo::NoModRef, + /*inaccessibleMem=*/LLVM::ModRefInfo::NoModRef); + auto funcAttrs = intel::convergentNoUnwindWillReturnAttrs; + funcAttrs.memEffectsAttr = memAttr; + + result = intel::createDeviceFunctionCall(rewriter, fnName, cTy, argTypes, + args, {}, funcAttrs) + ->getResult(0); + } - Value result = intel::createDeviceFunctionCall( - rewriter, fnName, cTy, argTypes, args, {}, funcAttrs) - ->getResult(0); if (cOrigTy != cTy) result = rewriter.create(loc, cOrigTy, result); @@ -508,7 +533,8 @@ struct TritonMatrix2DBlockLoadLowering LogicalResult matchAndRewrite(TritonGEN::Matrix2DBlockLoadOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - if (!isSPVBuiltinAvailable(op)) { + if (tools::getBoolEnv("TRITONGEN_FORCE_GENISA") || + !isSPVBuiltinAvailable(op)) { // Fallback to GenISA interface. rewriter.replaceOp(op, createGenISA2DBlockRead(op, rewriter)); return success(); @@ -583,6 +609,12 @@ struct TritonMatrix2DBlockStoreLowering LogicalResult matchAndRewrite(TritonGEN::Matrix2DBlockStoreOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { + // TODO: Remove GenISA lowering after PoC productization is completed. + if (tools::getBoolEnv("TRITONGEN_FORCE_GENISA")) { + rewriter.replaceOp(op, createGenISA2DBlockWrite(op, rewriter)); + return success(); + } + MLIRContext *ctx = rewriter.getContext(); Location loc = op->getLoc(); auto b = TritonLLVMOpBuilder(loc, rewriter); @@ -651,6 +683,13 @@ struct TritonMatrix2DBlockPrefetchLowering LogicalResult matchAndRewrite(TritonGEN::Matrix2DBlockPrefetchOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { + // TODO: Remove GenISA lowering after PoC productization is completed. 
+ bool useGenISA = tools::getBoolEnv("TRITONGEN_FORCE_GENISA"); + if (useGenISA) { + rewriter.replaceOp(op, createGenISA2DBlockPrefetch(op, rewriter)); + return success(); + } + MLIRContext *ctx = rewriter.getContext(); Location loc = op->getLoc(); auto b = TritonLLVMOpBuilder(loc, rewriter); diff --git a/third_party/intel/lib/TritonGENToLLVM/libGenISAIntrinsics.a b/third_party/intel/lib/TritonGENToLLVM/libGenISAIntrinsics.a new file mode 100644 index 0000000000..a7344f63db Binary files /dev/null and b/third_party/intel/lib/TritonGENToLLVM/libGenISAIntrinsics.a differ diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp b/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp index 7e9db10572..d31dee3e9d 100644 --- a/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp +++ b/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp @@ -1648,6 +1648,15 @@ struct LoadOpConversion usePackedType = true; } + if (isTransposeRequired) { + if (!usePackedType) { + // use the d32 transpose 2d load. 
+ loadResultElemType = i32_ty; + packedElemsPerLanePerDPASInst = 32 / elemSizeInBits; + usePackedType = true; + } + } + Type packedDPASOperandType = LLVM::getVectorType(loadResultElemType, packedElemsPerLanePerDPASInst); @@ -2105,6 +2114,10 @@ struct LoadOpConversion rewriter.eraseOp(load2dOp); return failure(); } +#if 0 + targetInfo.printf(rewriter, "base: %p, baseWidth: %d, baseHeight:%d, pitch:%d, offset_x:%d, offset_y:%d, loadVal: %d", + {base, base_width, baseHeight, base_pitch, offsetX, offsetY, load2dOp.getResult()}); +#endif LLVM_DEBUG(llvm::dbgs() << "Generated load op: " << load2dOp << "\n"); unsigned packedRowNum = opIdx == DpasEncodingAttr::OpIdx::OperandA @@ -2166,11 +2179,14 @@ struct LoadOpConversion vblk * packedColNumPerVBlock + col) << ", " << std::to_string(k + row) << "\n"; }); + auto ret = b.bitcast(loadVal, unpackedDPASOperandType); +#if 0 + targetInfo.printf(rewriter, "loadVal: %d", {ret}); +#endif loadVals[{outer * packedColNum * numLoadPerOutRepCluster + rep * packedColNum + vblk * packedColNumPerVBlock + col, - k + row}] = - b.bitcast(loadVal, unpackedDPASOperandType); + k + row}] = ret; } break; case DpasEncodingAttr::OpIdx::OperandC: { llvm_unreachable("unexpected OpIdx::OperandC");