@@ -372,18 +372,15 @@ static MachineInstr *getBlockStructInstr(Register ParamReg,
372372 // We expect the following sequence of instructions:
373373 // %0:_(pN) = G_INTRINSIC_W_SIDE_EFFECTS intrinsic(@llvm.spv.alloca)
374374 // or = G_GLOBAL_VALUE @block_literal_global
375- // %1:_(pN) = G_INTRINSIC_W_SIDE_EFFECTS intrinsic(@llvm.spv.bitcast), %0
376- // %2:_(p4) = G_ADDRSPACE_CAST %1:_(pN)
375+ // %1:_(p4) = G_ADDRSPACE_CAST %0:_(pN)
377376 MachineInstr *MI = MRI->getUniqueVRegDef (ParamReg);
378377 assert (MI->getOpcode () == TargetOpcode::G_ADDRSPACE_CAST &&
379378 MI->getOperand (1 ).isReg ());
380- Register BitcastReg = MI->getOperand (1 ).getReg ();
381- MachineInstr *BitcastMI = MRI->getUniqueVRegDef (BitcastReg);
382- assert (isSpvIntrinsic (*BitcastMI, Intrinsic::spv_bitcast) &&
383- BitcastMI->getOperand (2 ).isReg ());
384- Register ValueReg = BitcastMI->getOperand (2 ).getReg ();
385- MachineInstr *ValueMI = MRI->getUniqueVRegDef (ValueReg);
386- return ValueMI;
379+ Register PtrReg = MI->getOperand (1 ).getReg ();
380+ MachineInstr *PtrMI = MRI->getUniqueVRegDef (PtrReg);
381+ assert (PtrMI->getOpcode () == TargetOpcode::G_GLOBAL_VALUE ||
382+ isSpvIntrinsic (*PtrMI, Intrinsic::spv_alloca));
383+ return PtrMI;
387384}
388385
389386// Return an integer constant corresponding to the given register and
@@ -2509,6 +2506,59 @@ static bool buildEnqueueKernel(const SPIRV::IncomingCall *Call,
25092506 return true ;
25102507}
25112508
2509+ static bool buildNDRangeSubGroup (const SPIRV::IncomingCall *Call,
2510+ unsigned Opcode, MachineIRBuilder &MIRBuilder,
2511+ SPIRVGlobalRegistry *GR) {
2512+ MachineRegisterInfo *MRI = MIRBuilder.getMRI ();
2513+ const DataLayout &DL = MIRBuilder.getDataLayout ();
2514+
2515+ auto MIB = MIRBuilder.buildInstr (Opcode)
2516+ .addDef (Call->ReturnRegister )
2517+ .addUse (GR->getSPIRVTypeID (Call->ReturnType ))
2518+ .addUse (Call->Arguments [0 ]);
2519+ unsigned int BlockFIdx = 1 ;
2520+ MachineInstr *BlockMI = getBlockStructInstr (Call->Arguments [BlockFIdx], MRI);
2521+ assert (BlockMI->getOpcode () == TargetOpcode::G_GLOBAL_VALUE);
2522+ // Invoke: Pointer to invoke function.
2523+ Register BlockFReg = BlockMI->getOperand (0 ).getReg ();
2524+ MIB.addUse (BlockFReg);
2525+ MRI->setRegClass (BlockFReg, &SPIRV::pIDRegClass);
2526+
2527+ Register BlockLiteralReg = Call->Arguments [BlockFIdx + 1 ];
2528+ // Param: Pointer to block literal.
2529+ MIB.addUse (BlockLiteralReg);
2530+ BlockMI = MRI->getUniqueVRegDef (BlockLiteralReg);
2531+ Register BlockMIReg =
2532+ stripAddrspaceCast (BlockMI->getOperand (1 ).getReg (), *MRI);
2533+ BlockMI = MRI->getUniqueVRegDef (BlockMIReg);
2534+
2535+ if (BlockMI->getOpcode () == TargetOpcode::G_GLOBAL_VALUE) {
2536+ // Size and align are given explicitly here.
2537+ const GlobalValue *GV = BlockMI->getOperand (1 ).getGlobal ();
2538+
2539+ const GlobalVariable *BlockGV = dyn_cast<GlobalVariable>(GV);
2540+ assert (BlockGV->hasInitializer () &&
2541+ " Block literal should have an initializer" );
2542+ const Constant *Init = BlockGV->getInitializer ();
2543+ const ConstantStruct *CS = dyn_cast<ConstantStruct>(Init);
2544+ // Extract fields
2545+ const ConstantInt *SizeConst = dyn_cast<ConstantInt>(CS->getOperand (0 ));
2546+ const ConstantInt *AlignConst = dyn_cast<ConstantInt>(CS->getOperand (1 ));
2547+ uint64_t BlockSize = SizeConst->getZExtValue ();
2548+ uint64_t BlockAlign = AlignConst->getZExtValue ();
2549+ MIB.addUse (buildConstantIntReg32 (BlockSize, MIRBuilder, GR));
2550+ MIB.addUse (buildConstantIntReg32 (BlockAlign, MIRBuilder, GR));
2551+ } else {
2552+ Type *PType = const_cast <Type *>(getBlockStructType (BlockLiteralReg, MRI));
2553+ // Fallback to default if not found
2554+ MIB.addUse (
2555+ buildConstantIntReg32 (DL.getTypeStoreSize (PType), MIRBuilder, GR));
2556+ MIB.addUse (buildConstantIntReg32 (DL.getPrefTypeAlign (PType).value (),
2557+ MIRBuilder, GR));
2558+ }
2559+ return true ;
2560+ }
2561+
25122562static bool generateEnqueueInst (const SPIRV::IncomingCall *Call,
25132563 MachineIRBuilder &MIRBuilder,
25142564 SPIRVGlobalRegistry *GR) {
@@ -2544,6 +2594,9 @@ static bool generateEnqueueInst(const SPIRV::IncomingCall *Call,
25442594 return buildNDRange (Call, MIRBuilder, GR);
25452595 case SPIRV::OpEnqueueKernel:
25462596 return buildEnqueueKernel (Call, MIRBuilder, GR);
2597+ case SPIRV::OpGetKernelNDrangeSubGroupCount:
2598+ case SPIRV::OpGetKernelNDrangeMaxSubGroupSize:
2599+ return buildNDRangeSubGroup (Call, Opcode, MIRBuilder, GR);
25472600 default :
25482601 return false ;
25492602 }
0 commit comments