@@ -265,16 +265,50 @@ class OpLowerer {
265265
266266 /// Replace uses of \c Intrin with the values in the `dx.ResRet` of \c Op.
267267 /// Since we expect to be post-scalarization, make an effort to avoid vectors.
268- Error replaceResRetUses (CallInst *Intrin, CallInst *Op) {
268+ Error replaceResRetUses (CallInst *Intrin, CallInst *Op, bool HasCheckBit ) {
269269 IRBuilder<> &IRB = OpBuilder.getIRB ();
270270
271+ Instruction *OldResult = Intrin;
271272 Type *OldTy = Intrin->getType ();
272273
274+ if (HasCheckBit) {
275+ auto *ST = cast<StructType>(OldTy);
276+
277+ Value *CheckOp = nullptr ;
278+ Type *Int32Ty = IRB.getInt32Ty ();
279+ for (Use &U : make_early_inc_range (OldResult->uses ())) {
280+ if (auto *EVI = dyn_cast<ExtractValueInst>(U.getUser ())) {
281+ ArrayRef<unsigned > Indices = EVI->getIndices ();
282+ assert (Indices.size () == 1 );
283+ // We're only interested in uses of the check bit for now.
284+ if (Indices[0 ] != 1 )
285+ continue ;
286+ if (!CheckOp) {
287+ Value *NewEVI = IRB.CreateExtractValue (Op, 4 );
288+ Expected<CallInst *> OpCall = OpBuilder.tryCreateOp (
289+ OpCode::CheckAccessFullyMapped, {NewEVI}, Int32Ty);
290+ if (Error E = OpCall.takeError ())
291+ return E;
292+ CheckOp = *OpCall;
293+ }
294+ EVI->replaceAllUsesWith (CheckOp);
295+ EVI->eraseFromParent ();
296+ }
297+ }
298+
299+ OldResult = cast<Instruction>(IRB.CreateExtractValue (Op, 0 ));
300+ OldTy = ST->getElementType (0 );
301+ }
302+
273303 // For scalars, we just extract the first element.
274304 if (!isa<FixedVectorType>(OldTy)) {
275305 Value *EVI = IRB.CreateExtractValue (Op, 0 );
276- Intrin->replaceAllUsesWith (EVI);
277- Intrin->eraseFromParent ();
306+ OldResult->replaceAllUsesWith (EVI);
307+ OldResult->eraseFromParent ();
308+ if (OldResult != Intrin) {
309+ assert (Intrin->use_empty () && " Intrinsic still has uses?" );
310+ Intrin->eraseFromParent ();
311+ }
278312 return Error::success ();
279313 }
280314
@@ -283,7 +317,7 @@ class OpLowerer {
283317
284318 // The users of the operation should all be scalarized, so we attempt to
285319 // replace the extractelements with extractvalues directly.
286- for (Use &U : make_early_inc_range (Intrin ->uses ())) {
320+ for (Use &U : make_early_inc_range (OldResult ->uses ())) {
287321 if (auto *EEI = dyn_cast<ExtractElementInst>(U.getUser ())) {
288322 if (auto *IndexOp = dyn_cast<ConstantInt>(EEI->getIndexOperand ())) {
289323 size_t IndexVal = IndexOp->getZExtValue ();
@@ -331,22 +365,27 @@ class OpLowerer {
331365 // If we still have uses, then we're not fully scalarized and need to
332366 // recreate the vector. This should only happen for things like exported
333367 // functions from libraries.
334- if (!Intrin ->use_empty ()) {
368+ if (!OldResult ->use_empty ()) {
335369 for (int I = 0 , E = N; I != E; ++I)
336370 if (!Extracts[I])
337371 Extracts[I] = IRB.CreateExtractValue (Op, I);
338372
339373 Value *Vec = UndefValue::get (OldTy);
340374 for (int I = 0 , E = N; I != E; ++I)
341375 Vec = IRB.CreateInsertElement (Vec, Extracts[I], I);
342- Intrin->replaceAllUsesWith (Vec);
376+ OldResult->replaceAllUsesWith (Vec);
377+ }
378+
379+ OldResult->eraseFromParent ();
380+ if (OldResult != Intrin) {
381+ assert (Intrin->use_empty () && " Intrinsic still has uses?" );
382+ Intrin->eraseFromParent ();
343383 }
344384
345- Intrin->eraseFromParent ();
346385 return Error::success ();
347386 }
348387
349- [[nodiscard]] bool lowerTypedBufferLoad (Function &F) {
388+ [[nodiscard]] bool lowerTypedBufferLoad (Function &F, bool HasCheckBit ) {
350389 IRBuilder<> &IRB = OpBuilder.getIRB ();
351390 Type *Int32Ty = IRB.getInt32Ty ();
352391
@@ -358,14 +397,17 @@ class OpLowerer {
358397 Value *Index0 = CI->getArgOperand (1 );
359398 Value *Index1 = UndefValue::get (Int32Ty);
360399
361- Type *NewRetTy = OpBuilder.getResRetType (CI->getType ()->getScalarType ());
400+ Type *OldTy = CI->getType ();
401+ if (HasCheckBit)
402+ OldTy = cast<StructType>(OldTy)->getElementType (0 );
403+ Type *NewRetTy = OpBuilder.getResRetType (OldTy->getScalarType ());
362404
363405 std::array<Value *, 3 > Args{Handle, Index0, Index1};
364406 Expected<CallInst *> OpCall =
365407 OpBuilder.tryCreateOp (OpCode::BufferLoad, Args, NewRetTy);
366408 if (Error E = OpCall.takeError ())
367409 return E;
368- if (Error E = replaceResRetUses (CI, *OpCall))
410+ if (Error E = replaceResRetUses (CI, *OpCall, HasCheckBit ))
369411 return E;
370412
371413 return Error::success ();
@@ -434,7 +476,10 @@ class OpLowerer {
434476 HasErrors |= lowerHandleFromBinding (F);
435477 break ;
436478 case Intrinsic::dx_typedBufferLoad:
437- HasErrors |= lowerTypedBufferLoad (F);
479+ HasErrors |= lowerTypedBufferLoad (F, /* HasCheckBit=*/ false );
480+ break ;
481+ case Intrinsic::dx_typedBufferLoad_checkbit:
482+ HasErrors |= lowerTypedBufferLoad (F, /* HasCheckBit=*/ true );
438483 break ;
439484 case Intrinsic::dx_typedBufferStore:
440485 HasErrors |= lowerTypedBufferStore (F);
0 commit comments