Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions llvm/docs/SPIRVUsage.rst
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,8 @@ list of supported SPIR-V extensions, sorted alphabetically by their extension na
- Adds the ability to specify the maximum error for floating-point operations.
* - ``SPV_INTEL_ternary_bitwise_function``
- Adds a bitwise instruction on three operands and a look-up table index for specifying the bitwise operation to perform.
* - ``SPV_INTEL_subgroup_matrix_multiply_accumulate``
- Adds an instruction to compute the matrix product of an M x K matrix with a K x N matrix and then add an M x N matrix.

To enable multiple extensions, list them separated by comma. For example, to enable support for atomic operations on floating-point numbers and arbitrary precision integers, use:

Expand Down
10 changes: 8 additions & 2 deletions llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1161,9 +1161,15 @@ static bool generateGroupInst(const SPIRV::IncomingCall *Call,

MachineRegisterInfo *MRI = MIRBuilder.getMRI();
if (Call->isSpirvOp()) {
if (GroupBuiltin->NoGroupOperation)
if (GroupBuiltin->NoGroupOperation) {
SmallVector<uint32_t, 1> ImmArgs;
if (GroupBuiltin->Opcode ==
SPIRV::OpSubgroupMatrixMultiplyAccumulateINTEL &&
Call->Arguments.size() > 4)
ImmArgs.push_back(getConstFromIntrinsic(Call->Arguments[4], MRI));
return buildOpFromWrapper(MIRBuilder, GroupBuiltin->Opcode, Call,
GR->getSPIRVTypeID(Call->ReturnType));
GR->getSPIRVTypeID(Call->ReturnType), ImmArgs);
}

// Group Operation is a literal
Register GroupOpReg = Call->Arguments[1];
Expand Down
4 changes: 4 additions & 0 deletions llvm/lib/Target/SPIRV/SPIRVBuiltins.td
Original file line number Diff line number Diff line change
Expand Up @@ -763,6 +763,7 @@ class GroupBuiltin<string name, Op operation> {
bit NoGroupOperation = !or(IsElect, IsAllOrAny, IsAllEqual,
IsBallot, IsInverseBallot,
IsBallotBitExtract, IsBallotFindBit,
!eq(operation, OpSubgroupMatrixMultiplyAccumulateINTEL),
!eq(operation, OpGroupNonUniformShuffle),
!eq(operation, OpGroupNonUniformShuffleXor),
!eq(operation, OpGroupNonUniformShuffleUp),
Expand Down Expand Up @@ -847,6 +848,9 @@ defm : DemangledGroupBuiltinWrapper<"__spirv_GroupNonUniformBallotFindLSB", 2, 2
defm : DemangledGroupBuiltin<"group_ballot_find_msb", OnlySub, OpGroupNonUniformBallotFindMSB>;
defm : DemangledGroupBuiltinWrapper<"__spirv_GroupNonUniformBallotFindMSB", 2, 2, OpGroupNonUniformBallotFindMSB>;

// SPV_INTEL_subgroup_matrix_multiply_accumulate
defm : DemangledGroupBuiltinWrapper<"__spirv_SubgroupMatrixMultiplyAccumulateINTEL", 4, 5, OpSubgroupMatrixMultiplyAccumulateINTEL>;

// cl_khr_subgroup_shuffle
defm : DemangledGroupBuiltin<"group_shuffle", OnlySub, OpGroupNonUniformShuffle>;
defm : DemangledGroupBuiltinWrapper<"__spirv_GroupNonUniformShuffle", 3, 3, OpGroupNonUniformShuffle>;
Expand Down
3 changes: 3 additions & 0 deletions llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,9 @@ static const std::map<std::string, SPIRV::Extension::Extension, std::less<>>
SPIRV::Extension::Extension::SPV_INTEL_long_composites},
{"SPV_INTEL_fp_max_error",
SPIRV::Extension::Extension::SPV_INTEL_fp_max_error},
{"SPV_INTEL_subgroup_matrix_multiply_accumulate",
SPIRV::Extension::Extension::
SPV_INTEL_subgroup_matrix_multiply_accumulate},
{"SPV_INTEL_ternary_bitwise_function",
SPIRV::Extension::Extension::SPV_INTEL_ternary_bitwise_function}};

Expand Down
4 changes: 4 additions & 0 deletions llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
Original file line number Diff line number Diff line change
Expand Up @@ -736,6 +736,10 @@ def OpGroupFMax: OpGroup<"FMax", 269>;
def OpGroupUMax: OpGroup<"UMax", 270>;
def OpGroupSMax: OpGroup<"SMax", 271>;

def OpSubgroupMatrixMultiplyAccumulateINTEL: Op<6237, (outs ID:$res),
(ins TYPE:$ty, ID:$KDim, ID:$A, ID:$B, ID:$C, variable_ops),
"$res = OpSubgroupMatrixMultiplyAccumulateINTEL $ty $KDim $A $B $C">;

// TODO: 3.42.22. Device-Side Enqueue Instructions
def OpEnqueueKernel: Op<292, (outs ID:$res), (ins TYPE:$type, ID:$queue, ID:$flags, ID:$NDR, ID:$nevents, ID:$wevents,
ID:$revent, ID:$invoke, ID:$param, ID:$psize, ID:$palign, variable_ops),
Expand Down
14 changes: 14 additions & 0 deletions llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1799,6 +1799,20 @@ void addInstrRequirements(const MachineInstr &MI,
Reqs.addCapability(SPIRV::Capability::LongCompositesINTEL);
break;
}
case SPIRV::OpSubgroupMatrixMultiplyAccumulateINTEL: {
if (!ST.canUseExtension(
SPIRV::Extension::SPV_INTEL_subgroup_matrix_multiply_accumulate))
report_fatal_error(
"OpSubgroupMatrixMultiplyAccumulateINTEL instruction requires the "
"following SPIR-V "
"extension: SPV_INTEL_subgroup_matrix_multiply_accumulate",
false);
Reqs.addExtension(
SPIRV::Extension::SPV_INTEL_subgroup_matrix_multiply_accumulate);
Reqs.addCapability(
SPIRV::Capability::SubgroupMatrixMultiplyAccumulateINTEL);
break;
}
case SPIRV::OpBitwiseFunctionINTEL: {
if (!ST.canUseExtension(
SPIRV::Extension::SPV_INTEL_ternary_bitwise_function))
Expand Down
2 changes: 2 additions & 0 deletions llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,7 @@ defm SPV_INTEL_long_composites : ExtensionOperand<117>;
defm SPV_INTEL_memory_access_aliasing : ExtensionOperand<118>;
defm SPV_INTEL_fp_max_error : ExtensionOperand<119>;
defm SPV_INTEL_ternary_bitwise_function : ExtensionOperand<120>;
defm SPV_INTEL_subgroup_matrix_multiply_accumulate : ExtensionOperand<121>;

//===----------------------------------------------------------------------===//
// Multiclass used to define Capabilities enum values and at the same time
Expand Down Expand Up @@ -515,6 +516,7 @@ defm BindlessImagesINTEL : CapabilityOperand<6528, 0, 0, [SPV_INTEL_bindless_ima
defm MemoryAccessAliasingINTEL : CapabilityOperand<5910, 0, 0, [SPV_INTEL_memory_access_aliasing], []>;
defm FPMaxErrorINTEL : CapabilityOperand<6169, 0, 0, [SPV_INTEL_fp_max_error], []>;
defm TernaryBitwiseFunctionINTEL : CapabilityOperand<6241, 0, 0, [SPV_INTEL_ternary_bitwise_function], []>;
defm SubgroupMatrixMultiplyAccumulateINTEL : CapabilityOperand<6236, 0, 0, [SPV_INTEL_subgroup_matrix_multiply_accumulate], []>;

//===----------------------------------------------------------------------===//
// Multiclass used to define SourceLanguage enum values and at the same time
Expand Down
Loading
Loading