Skip to content

Commit e7bcd80

Browse files
authored
[SPIRV] Use OpCopyMemory for logical SPIRV memcpy (llvm#169348)
This commit modifies the SPIRV instruction selector to emit `OpCopyMemory` instead of `OpCopyMemorySized` when generating SPIRV for logical addressing. Previously, `G_MEMCPY` was translated to `OpCopyMemorySized`, which requires an explicit size operand. However, for logical SPIRV, the size of the pointee type is implicitly known. This change ensures that `OpCopyMemory` is used, which is more appropriate for logical SPIRV and aligns with the SPIR-V specification for logical addressing.
1 parent 6e983e3 commit e7bcd80

File tree

2 files changed

+129
-44
lines changed

2 files changed

+129
-44
lines changed

llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp

Lines changed: 97 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,9 @@ class SPIRVInstructionSelector : public InstructionSelector {
151151
bool selectStackRestore(MachineInstr &I) const;
152152

153153
bool selectMemOperation(Register ResVReg, MachineInstr &I) const;
154+
Register getOrCreateMemSetGlobal(MachineInstr &I) const;
155+
bool selectCopyMemory(MachineInstr &I, Register SrcReg) const;
156+
bool selectCopyMemorySized(MachineInstr &I, Register SrcReg) const;
154157

155158
bool selectAtomicRMW(Register ResVReg, const SPIRVType *ResType,
156159
MachineInstr &I, unsigned NewOpcode,
@@ -1623,50 +1626,79 @@ bool SPIRVInstructionSelector::selectStackRestore(MachineInstr &I) const {
16231626
.constrainAllUses(TII, TRI, RBI);
16241627
}
16251628

1626-
bool SPIRVInstructionSelector::selectMemOperation(Register ResVReg,
1627-
MachineInstr &I) const {
1629+
// Builds the constant source buffer used when lowering a G_MEMSET (the only
// caller visible here is selectMemOperation's G_MEMSET path).
//
// Reads the fill value from operand 1 and the element count from operand 2 of
// I, creates an i8 array of that length, and emits an OpVariable in the
// UniformConstant storage class whose initializer is a constant int array
// built from the fill value. The variable is decorated as Constant.
//
// Returns the virtual register defined by the OpVariable, or an invalid
// Register() when constraining the OpVariable's operands fails.
//
// Despite the "getOrCreate" name, a new GlobalVariable is created on every
// call; see the TODO below.
Register
1630+
SPIRVInstructionSelector::getOrCreateMemSetGlobal(MachineInstr &I) const {
1631+
MachineIRBuilder MIRBuilder(I);
1632+
assert(I.getOperand(1).isReg() && I.getOperand(2).isReg());
1633+
1634+
// TODO: check if we have such GV, add init, use buildGlobalVariable.
1635+
// NOTE(review): the length is held in `unsigned` here, whereas
// selectCopyMemory keeps the same operand in uint64_t. Confirm that lengths
// above UINT_MAX cannot reach this point.
unsigned Num = getIConstVal(I.getOperand(2).getReg(), MRI);
1636+
Function &CurFunction = GR.CurMF->getFunction();
1637+
Type *LLVMArrTy =
1638+
ArrayType::get(IntegerType::get(CurFunction.getContext(), 8), Num);
1639+
// The variable is constructed directly into the function's parent module,
// which takes ownership of it. It is created as a constant with internal
// linkage and a zero initializer of type [Num x i8].
GlobalVariable *GV = new GlobalVariable(*CurFunction.getParent(), LLVMArrTy,
1640+
true, GlobalValue::InternalLinkage,
1641+
Constant::getNullValue(LLVMArrTy));
1642+
1643+
// NOTE(review): ArrTy looks like the same [Num x i8] type as LLVMArrTy, only
// reached through I.getMF() instead of GR.CurMF. Confirm the two always
// refer to the same function before merging them.
Type *ValTy = Type::getInt8Ty(I.getMF()->getFunction().getContext());
1644+
Type *ArrTy = ArrayType::get(ValTy, Num);
1645+
// SPIR-V pointer-to-array type used as the result type of the OpVariable.
SPIRVType *VarTy = GR.getOrCreateSPIRVPointerType(
1646+
ArrTy, MIRBuilder, SPIRV::StorageClass::UniformConstant);
1647+
1648+
// SPIR-V array type used for the constant initializer.
SPIRVType *SpvArrTy = GR.getOrCreateSPIRVType(
1649+
ArrTy, MIRBuilder, SPIRV::AccessQualifier::None, false);
1650+
1651+
// Initializer: a constant int array built from the fill value Val and the
// length Num.
unsigned Val = getIConstVal(I.getOperand(1).getReg(), MRI);
1652+
Register Const = GR.getOrCreateConstIntArray(Val, Num, I, SpvArrTy, TII);
1653+
1654+
Register VarReg = MRI->createGenericVirtualRegister(LLT::scalar(64));
1655+
// OpVariable <VarReg> : <VarTy> UniformConstant, initializer = Const,
// inserted before I.
auto MIBVar =
1656+
BuildMI(*I.getParent(), I, I.getDebugLoc(), TII.get(SPIRV::OpVariable))
1657+
.addDef(VarReg)
1658+
.addUse(GR.getSPIRVTypeID(VarTy))
1659+
.addImm(SPIRV::StorageClass::UniformConstant)
1660+
.addUse(Const);
1661+
// An invalid register signals failure; the caller checks isValid().
if (!MIBVar.constrainAllUses(TII, TRI, RBI))
1662+
return Register();
1663+
1664+
// Record the new global together with its defining instruction and vreg in
// the global registry.
GR.add(GV, MIBVar);
1665+
GR.addGlobalObject(GV, GR.CurMF, VarReg);
1666+
1667+
buildOpDecorate(VarReg, I, TII, SPIRV::Decoration::Constant, {});
1668+
return VarReg;
1669+
}
1670+
1671+
// Emits an OpCopyMemory that copies from SrcReg to the destination pointer
// held in operand 0 of I.
//
// OpCopyMemory carries no size operand, so this function only accepts copies
// of exactly one whole pointee object. It aborts via report_fatal_error when:
//   - the destination and source pointee SPIR-V types differ,
//   - no LLVM type can be recovered for the pointee type, or
//   - the constant size in operand 2 differs from the pointee's store size.
//
// If I has memory operands, the first one is forwarded to the new instruction
// through addMemoryOperands. Returns the result of constraining the new
// instruction's operands.
bool SPIRVInstructionSelector::selectCopyMemory(MachineInstr &I,
1672+
Register SrcReg) const {
16281673
MachineBasicBlock &BB = *I.getParent();
1629-
Register SrcReg = I.getOperand(1).getReg();
1630-
bool Result = true;
1631-
if (I.getOpcode() == TargetOpcode::G_MEMSET) {
1674+
Register DstReg = I.getOperand(0).getReg();
1675+
SPIRVType *DstTy = GR.getSPIRVTypeForVReg(DstReg);
1676+
SPIRVType *SrcTy = GR.getSPIRVTypeForVReg(SrcReg);
1677+
// Both pointers must point to the same SPIR-V type.
// NOTE(review): DstTy and SrcTy are not checked for null before use;
// confirm getSPIRVTypeForVReg cannot fail here.
if (GR.getPointeeType(DstTy) != GR.getPointeeType(SrcTy))
1678+
report_fatal_error("OpCopyMemory requires operands to have the same type");
1679+
// Requested byte count, read from operand 2.
uint64_t CopySize = getIConstVal(I.getOperand(2).getReg(), MRI);
1680+
SPIRVType *PointeeTy = GR.getPointeeType(DstTy);
1681+
const Type *LLVMPointeeTy = GR.getTypeForSPIRVType(PointeeTy);
1682+
if (!LLVMPointeeTy)
1683+
report_fatal_error(
1684+
"Unable to determine pointee type size for OpCopyMemory");
1685+
const DataLayout &DL = I.getMF()->getFunction().getDataLayout();
1686+
// A partial or oversized copy cannot be expressed without a size operand, so
// the requested size must equal the pointee's store size. The const_cast
// presumably only adapts to the getTypeStoreSize signature.
if (CopySize != DL.getTypeStoreSize(const_cast<Type *>(LLVMPointeeTy)))
1687+
report_fatal_error(
1688+
"OpCopyMemory requires the size to match the pointee type size");
1689+
// OpCopyMemory <DstReg> <SrcReg>, inserted before I.
auto MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpCopyMemory))
1690+
.addUse(DstReg)
1691+
.addUse(SrcReg);
1692+
// Only the first memory operand of I is forwarded.
if (I.getNumMemOperands()) {
16321693
MachineIRBuilder MIRBuilder(I);
1633-
assert(I.getOperand(1).isReg() && I.getOperand(2).isReg());
1634-
unsigned Val = getIConstVal(I.getOperand(1).getReg(), MRI);
1635-
unsigned Num = getIConstVal(I.getOperand(2).getReg(), MRI);
1636-
Type *ValTy = Type::getInt8Ty(I.getMF()->getFunction().getContext());
1637-
Type *ArrTy = ArrayType::get(ValTy, Num);
1638-
SPIRVType *VarTy = GR.getOrCreateSPIRVPointerType(
1639-
ArrTy, MIRBuilder, SPIRV::StorageClass::UniformConstant);
1640-
1641-
SPIRVType *SpvArrTy = GR.getOrCreateSPIRVType(
1642-
ArrTy, MIRBuilder, SPIRV::AccessQualifier::None, false);
1643-
Register Const = GR.getOrCreateConstIntArray(Val, Num, I, SpvArrTy, TII);
1644-
// TODO: check if we have such GV, add init, use buildGlobalVariable.
1645-
Function &CurFunction = GR.CurMF->getFunction();
1646-
Type *LLVMArrTy =
1647-
ArrayType::get(IntegerType::get(CurFunction.getContext(), 8), Num);
1648-
// Module takes ownership of the global var.
1649-
GlobalVariable *GV = new GlobalVariable(*CurFunction.getParent(), LLVMArrTy,
1650-
true, GlobalValue::InternalLinkage,
1651-
Constant::getNullValue(LLVMArrTy));
1652-
Register VarReg = MRI->createGenericVirtualRegister(LLT::scalar(64));
1653-
auto MIBVar =
1654-
BuildMI(*I.getParent(), I, I.getDebugLoc(), TII.get(SPIRV::OpVariable))
1655-
.addDef(VarReg)
1656-
.addUse(GR.getSPIRVTypeID(VarTy))
1657-
.addImm(SPIRV::StorageClass::UniformConstant)
1658-
.addUse(Const);
1659-
Result &= MIBVar.constrainAllUses(TII, TRI, RBI);
1660-
1661-
GR.add(GV, MIBVar);
1662-
GR.addGlobalObject(GV, GR.CurMF, VarReg);
1663-
1664-
buildOpDecorate(VarReg, I, TII, SPIRV::Decoration::Constant, {});
1665-
SPIRVType *SourceTy = GR.getOrCreateSPIRVPointerType(
1666-
ValTy, I, SPIRV::StorageClass::UniformConstant);
1667-
SrcReg = MRI->createGenericVirtualRegister(LLT::scalar(64));
1668-
selectOpWithSrcs(SrcReg, SourceTy, I, {VarReg}, SPIRV::OpBitcast);
1694+
addMemoryOperands(*I.memoperands_begin(), MIB, MIRBuilder, GR);
16691695
}
1696+
return MIB.constrainAllUses(TII, TRI, RBI);
1697+
}
1698+
1699+
bool SPIRVInstructionSelector::selectCopyMemorySized(MachineInstr &I,
1700+
Register SrcReg) const {
1701+
MachineBasicBlock &BB = *I.getParent();
16701702
auto MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpCopyMemorySized))
16711703
.addUse(I.getOperand(0).getReg())
16721704
.addUse(SrcReg)
@@ -1675,9 +1707,30 @@ bool SPIRVInstructionSelector::selectMemOperation(Register ResVReg,
16751707
MachineIRBuilder MIRBuilder(I);
16761708
addMemoryOperands(*I.memoperands_begin(), MIB, MIRBuilder, GR);
16771709
}
1678-
Result &= MIB.constrainAllUses(TII, TRI, RBI);
1679-
if (ResVReg.isValid() && ResVReg != MIB->getOperand(0).getReg())
1680-
Result &= BuildCOPY(ResVReg, MIB->getOperand(0).getReg(), I);
1710+
return MIB.constrainAllUses(TII, TRI, RBI);
1711+
}
1712+
1713+
// Selects a memory-transfer instruction I into a SPIR-V copy.
//
// The source pointer is operand 1 of I, except for G_MEMSET: there a constant
// global holding the fill pattern is created (getOrCreateMemSetGlobal) and
// bitcast to a UniformConstant pointer to i8, and that bitcast result becomes
// the source.
//
// The copy is emitted as OpCopyMemory when the subtarget is logical SPIR-V and
// as OpCopyMemorySized otherwise. When ResVReg is valid and is not the
// destination register (operand 0), the destination is additionally copied
// into ResVReg.
//
// Returns false if creating the memset global fails or if any emitted
// instruction fails to select or constrain; true otherwise.
bool SPIRVInstructionSelector::selectMemOperation(Register ResVReg,
1714+
MachineInstr &I) const {
1715+
Register SrcReg = I.getOperand(1).getReg();
1716+
bool Result = true;
1717+
if (I.getOpcode() == TargetOpcode::G_MEMSET) {
1718+
Register VarReg = getOrCreateMemSetGlobal(I);
1719+
// An invalid register means the OpVariable could not be constrained.
if (!VarReg.isValid())
1720+
return false;
1721+
// View the [N x i8] constant as a plain i8 pointer so it can serve as the
// copy source.
Type *ValTy = Type::getInt8Ty(I.getMF()->getFunction().getContext());
1722+
SPIRVType *SourceTy = GR.getOrCreateSPIRVPointerType(
1723+
ValTy, I, SPIRV::StorageClass::UniformConstant);
1724+
SrcReg = MRI->createGenericVirtualRegister(LLT::scalar(64));
1725+
Result &= selectOpWithSrcs(SrcReg, SourceTy, I, {VarReg}, SPIRV::OpBitcast);
1726+
}
1727+
// NOTE(review): on the G_MEMSET path SrcReg points to i8, while
// selectCopyMemory aborts unless source and destination pointee types match.
// A memset under logical SPIR-V therefore appears to hit report_fatal_error
// unless the destination pointee is also i8. Confirm this is intended.
if (STI.isLogicalSPIRV()) {
1728+
Result &= selectCopyMemory(I, SrcReg);
1729+
} else {
1730+
Result &= selectCopyMemorySized(I, SrcReg);
1731+
}
1732+
// The copy instructions define no result, so the destination pointer is
// forwarded to ResVReg when a distinct result register was requested.
if (ResVReg.isValid() && ResVReg != I.getOperand(0).getReg())
1733+
Result &= BuildCOPY(ResVReg, I.getOperand(0).getReg(), I);
16811734
return Result;
16821735
}
16831736

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
; Verifies that a whole-object llvm.memcpy between two globals of the same
; struct type is emitted as OpCopyMemory (with the Aligned 4 memory operand)
; for the spirv-unknown-unknown triple. The second RUN line additionally
; validates the binary with spirv-val when spirv-tools is available.
; NOTE(review): spirv-unknown-unknown presumably selects logical addressing;
; confirm against the triple handling in the SPIR-V subtarget.
; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv-unknown-unknown %s -o - | FileCheck %s
2+
; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv-unknown-unknown %s -o - -filetype=obj | spirv-val %}
3+
4+
; CHECK: OpName %[[dst_var:[0-9]+]] "dst"
5+
; CHECK: OpName %[[src_var:[0-9]+]] "src"
6+
7+
; CHECK: %[[f32:[0-9]+]] = OpTypeFloat 32
8+
; CHECK: %[[structS:[0-9]+]] = OpTypeStruct %[[f32]] %[[f32]] %[[f32]] %[[f32]] %[[f32]]
9+
; CHECK: %[[ptr_crosswkgrp_structS:[0-9]+]] = OpTypePointer CrossWorkgroup %[[structS]]
10+
; Packed struct of five floats: 5 * 4 = 20 bytes, which is the exact length
; passed to the memcpy below.
%struct.S = type <{ float, float, float, float, float }>
11+
12+
; CHECK-DAG: %[[src_var]] = OpVariable %[[ptr_crosswkgrp_structS]] CrossWorkgroup
13+
@src = external dso_local addrspace(1) global %struct.S, align 4
14+
15+
; CHECK-DAG: %[[dst_var]] = OpVariable %[[ptr_crosswkgrp_structS]] CrossWorkgroup
16+
@dst = external dso_local addrspace(1) global %struct.S, align 4
17+
18+
; CHECK: %[[main_func:[0-9]+]] = OpFunction %{{[0-9]+}} None %{{[0-9]+}}
19+
; CHECK: %[[entry:[0-9]+]] = OpLabel
20+
; Function Attrs: mustprogress nofree noinline norecurse nosync nounwind willreturn memory(readwrite, inaccessiblemem: none, target_mem0: none, target_mem1: none)
21+
define void @main() local_unnamed_addr #0 {
22+
entry:
23+
; CHECK: OpCopyMemory %[[dst_var]] %[[src_var]] Aligned 4
24+
; NOTE(review): the intrinsic suffix is p0.p0 although both pointers are in
; addrspace(1); presumably the IR reader remangles the name. Confirm.
call void @llvm.memcpy.p0.p0.i64(ptr addrspace(1) align 4 @dst, ptr addrspace(1) align 4 @src, i64 20, i1 false)
25+
ret void
26+
; CHECK: OpReturn
27+
; CHECK: OpFunctionEnd
28+
}
29+
30+
attributes #0 = { "hlsl.numthreads"="8,1,1" "hlsl.shader"="compute" }
31+
32+

0 commit comments

Comments
 (0)