@@ -32106,6 +32106,15 @@ bool GenTree::CanDivOrModPossiblyOverflow(Compiler* comp) const
32106
32106
return true;
32107
32107
}
32108
32108
32109
+ //------------------------------------------------------------------------
32110
+ // gtFoldExprHWIntrinsic: Attempt to fold a HWIntrinsic
32111
+ //
32112
+ // Arguments:
32113
+ // tree - HWIntrinsic to fold
32114
+ //
32115
+ // Return Value:
32116
+ // folded expression if it could be folded, else the original tree
32117
+ //
32109
32118
#if defined(FEATURE_HW_INTRINSICS)
32110
32119
GenTree* Compiler::gtFoldExprHWIntrinsic(GenTreeHWIntrinsic* tree)
32111
32120
{
@@ -32249,7 +32258,8 @@ GenTree* Compiler::gtFoldExprHWIntrinsic(GenTreeHWIntrinsic* tree)
32249
32258
// We shouldn't find AND_NOT nodes since it should only be produced in lowering
32250
32259
assert(oper != GT_AND_NOT);
32251
32260
32252
- #if defined(FEATURE_MASKED_HW_INTRINSICS) && defined(TARGET_XARCH)
32261
+ #ifdef FEATURE_MASKED_HW_INTRINSICS
32262
+ #ifdef TARGET_XARCH
32253
32263
if (GenTreeHWIntrinsic::OperIsBitwiseHWIntrinsic(oper))
32254
32264
{
32255
32265
// Comparisons that produce masks lead to more verbose trees than
@@ -32367,7 +32377,75 @@ GenTree* Compiler::gtFoldExprHWIntrinsic(GenTreeHWIntrinsic* tree)
32367
32377
}
32368
32378
}
32369
32379
}
32370
- #endif // FEATURE_MASKED_HW_INTRINSICS && TARGET_XARCH
32380
+ #elif defined(TARGET_ARM64)
32381
+ // Check if the tree can be folded into a mask variant
32382
+ if (HWIntrinsicInfo::HasAllMaskVariant(tree->GetHWIntrinsicId()))
32383
+ {
32384
+ NamedIntrinsic maskVariant = HWIntrinsicInfo::GetMaskVariant(tree->GetHWIntrinsicId());
32385
+
32386
+ assert(opCount == (size_t)HWIntrinsicInfo::lookupNumArgs(maskVariant));
32387
+
32388
+ // Check all operands are valid
32389
+ bool canFold = true;
32390
+ if (ni == NI_Sve_ConditionalSelect)
32391
+ {
32392
+ assert(varTypeIsMask(op1));
32393
+ canFold = (op2->OperIsConvertMaskToVector() && op3->OperIsConvertMaskToVector());
32394
+ }
32395
+ else
32396
+ {
32397
+ for (size_t i = 1; i <= opCount && canFold; i++)
32398
+ {
32399
+ canFold &= tree->Op(i)->OperIsConvertMaskToVector();
32400
+ }
32401
+ }
32402
+
32403
+ if (canFold)
32404
+ {
32405
+ // Convert all the operands to masks
32406
+ for (size_t i = 1; i <= opCount; i++)
32407
+ {
32408
+ if (tree->Op(i)->OperIsConvertMaskToVector())
32409
+ {
32410
+ // Replace with op1.
32411
+ tree->Op(i) = tree->Op(i)->AsHWIntrinsic()->Op(1);
32412
+ }
32413
+ else if (tree->Op(i)->IsVectorZero())
32414
+ {
32415
+ // Replace the vector of zeroes with a mask of zeroes.
32416
+ tree->Op(i) = gtNewSimdFalseMaskByteNode();
32417
+ tree->Op(i)->SetMorphed(this);
32418
+ }
32419
+ assert(varTypeIsMask(tree->Op(i)));
32420
+ }
32421
+
32422
+ // Switch to the mask variant
32423
+ switch (opCount)
32424
+ {
32425
+ case 1:
32426
+ tree->ResetHWIntrinsicId(maskVariant, tree->Op(1));
32427
+ break;
32428
+ case 2:
32429
+ tree->ResetHWIntrinsicId(maskVariant, tree->Op(1), tree->Op(2));
32430
+ break;
32431
+ case 3:
32432
+ tree->ResetHWIntrinsicId(maskVariant, this, tree->Op(1), tree->Op(2), tree->Op(3));
32433
+ break;
32434
+ default:
32435
+ unreached();
32436
+ }
32437
+
32438
+ tree->gtType = TYP_MASK;
32439
+ tree->SetMorphed(this);
32440
+ tree = gtNewSimdCvtMaskToVectorNode(retType, tree, simdBaseJitType, simdSize)->AsHWIntrinsic();
32441
+ tree->SetMorphed(this);
32442
+ op1 = tree->Op(1);
32443
+ op2 = nullptr;
32444
+ op3 = nullptr;
32445
+ }
32446
+ }
32447
+ #endif // TARGET_ARM64
32448
+ #endif // FEATURE_MASKED_HW_INTRINSICS
32371
32449
32372
32450
GenTree* cnsNode = nullptr;
32373
32451
GenTree* otherNode = nullptr;
@@ -33754,7 +33832,7 @@ GenTree* Compiler::gtFoldExprHWIntrinsic(GenTreeHWIntrinsic* tree)
33754
33832
// op2 = op2 & op1
33755
33833
op2->AsVecCon()->EvaluateBinaryInPlace(GT_AND, false, simdBaseType, op1->AsVecCon());
33756
33834
33757
- // op3 = op2 & ~op1
33835
+ // op3 = op3 & ~op1
33758
33836
op3->AsVecCon()->EvaluateBinaryInPlace(GT_AND_NOT, false, simdBaseType, op1->AsVecCon());
33759
33837
33760
33838
// op2 = op2 | op3
@@ -33767,8 +33845,8 @@ GenTree* Compiler::gtFoldExprHWIntrinsic(GenTreeHWIntrinsic* tree)
33767
33845
33768
33846
#if defined(TARGET_ARM64)
33769
33847
case NI_Sve_ConditionalSelect:
33848
+ case NI_Sve_ConditionalSelect_Predicates:
33770
33849
{
33771
- assert(!varTypeIsMask(retType));
33772
33850
assert(varTypeIsMask(op1));
33773
33851
33774
33852
if (cnsNode != op1)
@@ -33797,10 +33875,11 @@ GenTree* Compiler::gtFoldExprHWIntrinsic(GenTreeHWIntrinsic* tree)
33797
33875
33798
33876
if (op2->IsCnsVec() && op3->IsCnsVec())
33799
33877
{
33878
+ assert(ni == NI_Sve_ConditionalSelect);
33800
33879
assert(op2->gtType == TYP_SIMD16);
33801
33880
assert(op3->gtType == TYP_SIMD16);
33802
33881
33803
- simd16_t op1SimdVal;
33882
+ simd16_t op1SimdVal = {} ;
33804
33883
EvaluateSimdCvtMaskToVector<simd16_t>(simdBaseType, &op1SimdVal, op1->AsMskCon()->gtSimdMaskVal);
33805
33884
33806
33885
// op2 = op2 & op1
@@ -33809,7 +33888,7 @@ GenTree* Compiler::gtFoldExprHWIntrinsic(GenTreeHWIntrinsic* tree)
33809
33888
op1SimdVal);
33810
33889
op2->AsVecCon()->gtSimd16Val = result;
33811
33890
33812
- // op3 = op2 & ~op1
33891
+ // op3 = op3 & ~op1
33813
33892
result = {};
33814
33893
EvaluateBinarySimd<simd16_t>(GT_AND_NOT, false, simdBaseType, &result, op3->AsVecCon()->gtSimd16Val,
33815
33894
op1SimdVal);
@@ -33820,6 +33899,30 @@ GenTree* Compiler::gtFoldExprHWIntrinsic(GenTreeHWIntrinsic* tree)
33820
33899
33821
33900
resultNode = op2;
33822
33901
}
33902
+ else if (op2->IsCnsMsk() && op3->IsCnsMsk())
33903
+ {
33904
+ assert(ni == NI_Sve_ConditionalSelect_Predicates);
33905
+
33906
+ // op2 = op2 & op1
33907
+ simdmask_t result = {};
33908
+ EvaluateBinaryMask<simd16_t>(GT_AND, false, simdBaseType, &result, op2->AsMskCon()->gtSimdMaskVal,
33909
+ op1->AsMskCon()->gtSimdMaskVal);
33910
+ op2->AsMskCon()->gtSimdMaskVal = result;
33911
+
33912
+ // op3 = op3 & ~op1
33913
+ result = {};
33914
+ EvaluateBinaryMask<simd16_t>(GT_AND_NOT, false, simdBaseType, &result,
33915
+ op3->AsMskCon()->gtSimdMaskVal, op1->AsMskCon()->gtSimdMaskVal);
33916
+ op3->AsMskCon()->gtSimdMaskVal = result;
33917
+
33918
+ // op2 = op2 | op3
33919
+ result = {};
33920
+ EvaluateBinaryMask<simd16_t>(GT_OR, false, simdBaseType, &result, op2->AsMskCon()->gtSimdMaskVal,
33921
+ op3->AsMskCon()->gtSimdMaskVal);
33922
+ op2->AsMskCon()->gtSimdMaskVal = result;
33923
+
33924
+ resultNode = op2;
33925
+ }
33823
33926
break;
33824
33927
}
33825
33928
#endif // TARGET_ARM64
0 commit comments