@@ -289,8 +289,7 @@ void CustomSafeOptPass::visitAnd(BinaryOperator& I) {
289289// also be written manually as
290290// uint32_t other_id = sg.get_local_id() ^ XOR_VALUE;
291291// r = select_from_group(sg, x, other_id);
292- void CustomSafeOptPass::visitShuffleIndex (llvm::CallInst* I)
293- {
292+ void CustomSafeOptPass::visitShuffleIndex (llvm::CallInst* I) {
294293 using namespace llvm ::PatternMatch;
295294 /*
296295 Pattern match
@@ -299,87 +298,148 @@ void CustomSafeOptPass::visitShuffleIndex(llvm::CallInst* I)
299298 %xor = xor i16 %[optional1], 1
300299 ...[optional2] = %xor
301300 %simdShuffle = call i32 @llvm.genx.GenISA.WaveShuffleIndex.i32(i32 %x, i32 %[optional2], i32 0)
302-
303- Optional can be any combinations of :
301+ Each optional part can be any combination of:
304302 * %and = and i16 %856, 63
305303 * %zext = zext i16 %857 to i32
306304 We ignore any combinations of those, as they don't change the final calculated value,
307305 and different permutations were observed.
308306 */
309307
308+ auto getInstructionIgnoringAndZext = [](Value* V, unsigned Opcode) -> Instruction* {
309+ while (auto * VI = dyn_cast<Instruction>(V)) {
310+ if (VI->getOpcode () == Opcode) {
311+ return VI;
312+ }
313+ else if (auto * ZI = dyn_cast<ZExtInst>(VI)) {
314+ // Check if zext is from i16 to i32
315+ if (ZI->getSrcTy ()->isIntegerTy (16 ) && ZI->getDestTy ()->isIntegerTy (32 )) {
316+ V = ZI->getOperand (0 ); // Skip over zext
317+ }
318+ else {
319+ return nullptr ; // Not the zext we are looking for
320+ }
321+ }
322+ else if (VI->getOpcode () == Instruction::And) {
323+ ConstantInt* andValueConstant = dyn_cast<ConstantInt>(VI->getOperand (1 ));
324+ // We handle "redundant values", so those which bits enable all of
325+ // 32 lanes, so 31, 63 (spotted in nature), 127, 255 etc.
326+ if (andValueConstant && ((andValueConstant->getZExtValue () & 31 ) != 31 )) {
327+ return nullptr ;
328+ }
329+ V = VI->getOperand (0 ); // Skip over and
330+ }
331+ else {
332+ return nullptr ; // Not a zext, and, or the specified opcode
333+ }
334+ }
335+ return nullptr ; // unreachable
336+ };
337+
338+ Value* indexOp = I->getOperand (1 );
339+
340+ // Get helper lanes parameter
310341 ConstantInt* enableHelperLanes = dyn_cast<ConstantInt>(I->getOperand (2 ));
311- if (!enableHelperLanes || enableHelperLanes-> getZExtValue () != 0 ) {
342+ if (!enableHelperLanes) {
312343 return ;
313344 }
314345
315- auto getInstructionIgnoringAndZext = []( Value* V, unsigned Opcode ) -> Instruction* {
316- while ( auto * VI = dyn_cast<Instruction>( V ) ) {
317- if ( VI->getOpcode () == Opcode ) {
318- return VI;
319- }
320- else if ( auto * ZI = dyn_cast<ZExtInst>( VI ) ) {
321- // Check if zext is from i16 to i32
322- if ( ZI->getSrcTy ()->isIntegerTy ( 16 ) && ZI->getDestTy ()->isIntegerTy ( 32 ) ) {
323- V = ZI->getOperand ( 0 ); // Skip over zext
324- } else {
325- return nullptr ; // Not the zext we are looking for
346+ // Try QuadBroadcast pattern if helper lanes = 1
347+ if (enableHelperLanes->getZExtValue () == 1 ) {
348+ auto * zextInst = dyn_cast<ZExtInst>(indexOp);
349+ if (zextInst && zextInst->getSrcTy ()->isIntegerTy (16 ) &&
350+ zextInst->getDestTy ()->isIntegerTy (32 )) {
351+
352+ auto * andInst = dyn_cast<Instruction>(zextInst->getOperand (0 ));
353+ if (andInst && andInst->getOpcode () == Instruction::And) {
354+ // Check for mask constant -4 (0xFFFC)
355+ auto * mask = dyn_cast<ConstantInt>(andInst->getOperand (1 ));
356+ if (mask && mask->getSExtValue () == -4 ) {
357+ uint32_t laneIdx = 0 ;
358+ Value* simdLaneOp = andInst->getOperand (0 );
359+
360+ // Check for or operation
361+ if (auto * orInst = dyn_cast<Instruction>(simdLaneOp)) {
362+ if (orInst->getOpcode () == Instruction::Or) {
363+ auto * constOffset = dyn_cast<ConstantInt>(orInst->getOperand (1 ));
364+ // Return if OR value is not a constant or is >= 4
365+ if (!constOffset || constOffset->getZExtValue () >= 4 ) {
366+ return ;
367+ }
368+ laneIdx = constOffset->getZExtValue () & 0x3 ;
369+ simdLaneOp = orInst->getOperand (0 );
370+ }
326371 }
327- }
328- else if ( VI->getOpcode () == Instruction::And ) {
329- ConstantInt* andValueConstant = dyn_cast<ConstantInt>( VI->getOperand ( 1 ) );
330- // We handle "redundant values", so those which bits enable all of
331- // 32 lanes, so 31, 63 (spotted in nature), 127, 255 etc.
332- if ( andValueConstant && (( andValueConstant->getZExtValue () & 31 ) != 31 ) ) {
333- return nullptr ;
372+
373+ // Check for simdLaneId
374+ auto * simdLaneCall = dyn_cast<CallInst>(simdLaneOp);
375+ if (simdLaneCall) {
376+ Function* simdIdF = simdLaneCall->getCalledFunction ();
377+ if (simdIdF &&
378+ GenISAIntrinsic::getIntrinsicID (simdIdF) == GenISAIntrinsic::GenISA_simdLaneId) {
379+
380+ // Pattern matched - create QuadBroadcast
381+ IRBuilder<> builder (I);
382+
383+ Function* quadBroadcastFunc = GenISAIntrinsic::getDeclaration (
384+ builder.GetInsertBlock ()->getParent ()->getParent (),
385+ GenISAIntrinsic::GenISA_QuadBroadcast,
386+ I->getType ());
387+
388+ Value* result = builder.CreateCall (quadBroadcastFunc,
389+ { I->getOperand (0 ), builder.getInt32 (laneIdx) },
390+ " quadBroadcast" );
391+
392+ I->replaceAllUsesWith (result);
393+ I->eraseFromParent ();
394+ return ;
395+ }
334396 }
335- V = VI->getOperand ( 0 ); // Skip over and
336- } else {
337- return nullptr ; // Not a zext, and, or the specified opcode
338397 }
339398 }
340- return nullptr ; // unreachable
341- };
399+ }
400+ }
401+
402+ // Try ShuffleXor pattern if helper lanes = 0
403+ if (enableHelperLanes->getZExtValue () != 0 ) {
404+ return ;
405+ }
342406
343- Instruction* xorInst = getInstructionIgnoringAndZext ( I-> getOperand ( 1 ) , Instruction::Xor );
344- if ( !xorInst )
407+ Instruction* xorInst = getInstructionIgnoringAndZext (indexOp , Instruction::Xor);
408+ if ( !xorInst)
345409 return ;
346410
347- auto xorOperand = xorInst->getOperand ( 0 );
348- auto xorValueConstant = dyn_cast<ConstantInt> ( xorInst->getOperand ( 1 ) );
349- if ( !xorValueConstant )
411+ auto xorOperand = xorInst->getOperand (0 );
412+ auto xorValueConstant = dyn_cast<ConstantInt>( xorInst->getOperand (1 ) );
413+ if ( !xorValueConstant)
350414 return ;
351415
352416 uint64_t xorValue = xorValueConstant->getZExtValue ();
353- if ( xorValue >= 16 )
354- {
417+ if (xorValue >= 16 ) {
355418 // currently not supported in the emitter
356419 return ;
357420 }
358421
359- auto simdLaneCandidate = getInstructionIgnoringAndZext ( xorOperand, Instruction::Call );
360-
422+ auto simdLaneCandidate = getInstructionIgnoringAndZext (xorOperand, Instruction::Call);
361423 if (!simdLaneCandidate)
362424 return ;
363425
364- CallInst* CI = cast<CallInst>( simdLaneCandidate );
426+ CallInst* CI = cast<CallInst>(simdLaneCandidate);
365427 Function* simdIdF = CI->getCalledFunction ();
366- if ( !simdIdF || GenISAIntrinsic::getIntrinsicID ( simdIdF ) != GenISAIntrinsic::GenISA_simdLaneId)
428+ if ( !simdIdF || GenISAIntrinsic::getIntrinsicID (simdIdF) != GenISAIntrinsic::GenISA_simdLaneId)
367429 return ;
368430
369- // since we didn't return earlier, pattern is found
370-
431+ // ShuffleXor pattern found
371432 auto insertShuffleXor = [](IRBuilder<>& builder,
372- Value* value,
373- uint32_t xorValue)
374- {
375- Function* simdShuffleXorFunc = GenISAIntrinsic::getDeclaration (
376- builder.GetInsertBlock ()->getParent ()->getParent (),
377- GenISAIntrinsic::GenISA_simdShuffleXor,
378- value->getType ());
379-
380- return builder.CreateCall (simdShuffleXorFunc,
381- { value, builder.getInt32 (xorValue) }, " simdShuffleXor" );
382- };
433+ Value* value,
434+ uint32_t xorValue) {
435+ Function* simdShuffleXorFunc = GenISAIntrinsic::getDeclaration (
436+ builder.GetInsertBlock ()->getParent ()->getParent (),
437+ GenISAIntrinsic::GenISA_simdShuffleXor,
438+ value->getType ());
439+
440+ return builder.CreateCall (simdShuffleXorFunc,
441+ { value, builder.getInt32 (xorValue) }, " simdShuffleXor" );
442+ };
383443
384444 Value* value = I->getOperand (0 );
385445 IRBuilder<> builder (I);
0 commit comments