@@ -127,10 +127,12 @@ def : Pat<(AArch64_requires_za_save), (RequiresZASavePseudo)>;
// Type profile for the RDSVL node: one result and one operand, each
// constrained to an integer type (SDTCisInt on indices 0 and 1).
127127def SDT_AArch64RDSVL : SDTypeProfile<1, 1, [SDTCisInt<0>, SDTCisInt<1>]>;
// SelectionDAG node for the "AArch64ISD::RDSVL" opcode, typed by
// SDT_AArch64RDSVL.
128128def AArch64rdsvl : SDNode<"AArch64ISD::RDSVL", SDT_AArch64RDSVL>;
129129
// Removed side of the hunk: the four i64 ComplexPattern matchers under their
// previous sme_cnts{b,h,w,d}_imm names.
130- def sme_cntsb_imm : ComplexPattern<i64, 1, "SelectRDVLImm<1, 31, 8>">;
131- def sme_cntsh_imm : ComplexPattern<i64, 1, "SelectRDVLImm<1, 31, 4>">;
132- def sme_cntsw_imm : ComplexPattern<i64, 1, "SelectRDVLImm<1, 31, 2>">;
133- def sme_cntsd_imm : ComplexPattern<i64, 1, "SelectRDVLImm<1, 31, 1>">;
// Added side of the hunk: the same four matchers, renamed with a _mul_imm
// suffix. The SelectRDVLImm template arguments are unchanged; only the last
// argument (8/4/2/1) differs between the b/h/w/d variants.
// NOTE(review): the arguments look like <low, high, scale> -- confirm against
// the SelectRDVLImm definition, which is not visible here.
130+ def sme_cntsb_mul_imm : ComplexPattern<i64, 1, "SelectRDVLImm<1, 31, 8>">;
131+ def sme_cntsh_mul_imm : ComplexPattern<i64, 1, "SelectRDVLImm<1, 31, 4>">;
132+ def sme_cntsw_mul_imm : ComplexPattern<i64, 1, "SelectRDVLImm<1, 31, 2>">;
133+ def sme_cntsd_mul_imm : ComplexPattern<i64, 1, "SelectRDVLImm<1, 31, 1>">;
134+
// New matcher backed by SelectRDSVLShiftImm<1, 31>; it is used as the shift
// amount operand of the shl-of-cntsd selection pattern.
135+ def sme_cnts_shl_imm : ComplexPattern<i64, 1, "SelectRDSVLShiftImm<1, 31>">;
134136
135137let Predicates = [HasSMEandIsNonStreamingSafe] in {
// The "rdsvl" instruction record, instantiated from sve_int_read_vl_a with the
// streaming_sve argument set to 1. It is the output instruction of the
// cntsd-based selection patterns in this scope.
136138def RDSVLI_XI : sve_int_read_vl_a<0b0, 0b11111, "rdsvl", /*streaming_sve=*/0b1>;
@@ -140,21 +142,21 @@ def ADDSVL_XXI : sve_int_arith_vl<0b0, "addsvl", /*streaming_sve=*/0b1>;
// Select an AArch64rdsvl node whose operand is an i32 simm6_32b immediate
// directly to RDSVLI_XI, forwarding that immediate unchanged.
140142def : Pat<(AArch64rdsvl (i32 simm6_32b:$imm)), (RDSVLI_XI simm6_32b:$imm)>;
141143
// Patterns for (mul int_aarch64_sme_cntsd, imm). Each one emits RDSVLI_XI on
// the matched immediate after it has been passed through the trunc_imm
// SDNodeXForm. The b variant uses the RDSVLI_XI result as is; the h/w/d
// variants wrap it in UBFMXri with operands (1, 63), (2, 63) and (3, 63),
// i.e. a logical shift right by 1, 2 and 3 (the LSR alias of UBFM).
// This hunk only renames the matchers to their *_mul_imm names; the output
// instructions are untouched context lines.
142144// e.g. cntsb() * imm
143- def : Pat<(i64 (mul (int_aarch64_sme_cntsd), (sme_cntsb_imm i64:$imm))),
145+ def : Pat<(i64 (mul (int_aarch64_sme_cntsd), (sme_cntsb_mul_imm i64:$imm))),
144146 (RDSVLI_XI (!cast<SDNodeXForm>("trunc_imm") $imm))>;
145- def : Pat<(i64 (mul (int_aarch64_sme_cntsd), (sme_cntsh_imm i64:$imm))),
147+ def : Pat<(i64 (mul (int_aarch64_sme_cntsd), (sme_cntsh_mul_imm i64:$imm))),
146148 (UBFMXri (RDSVLI_XI (!cast<SDNodeXForm>("trunc_imm") $imm)), 1, 63)>;
147- def : Pat<(i64 (mul (int_aarch64_sme_cntsd), (sme_cntsw_imm i64:$imm))),
149+ def : Pat<(i64 (mul (int_aarch64_sme_cntsd), (sme_cntsw_mul_imm i64:$imm))),
148150 (UBFMXri (RDSVLI_XI (!cast<SDNodeXForm>("trunc_imm") $imm)), 2, 63)>;
149- def : Pat<(i64 (mul (int_aarch64_sme_cntsd), (sme_cntsd_imm i64:$imm))),
151+ def : Pat<(i64 (mul (int_aarch64_sme_cntsd), (sme_cntsd_mul_imm i64:$imm))),
150152 (UBFMXri (RDSVLI_XI (!cast<SDNodeXForm>("trunc_imm") $imm)), 3, 63)>;
151153
// Removed side of the hunk: three fixed patterns for cntsd shifted left by
// the constants 1, 2 and 3, all built from RDSVLI_XI with immediate 1.
152- // e.g. cntsb()
153- def: Pat<(i64 (shl (int_aarch64_sme_cntsd), (i64 1))), (UBFMXri (RDSVLI_XI 1), 2, 63)>;
154- def: Pat<(i64 (shl (int_aarch64_sme_cntsd), (i64 2))), (UBFMXri (RDSVLI_XI 1), 1, 63)>;
155- def: Pat<(i64 (shl (int_aarch64_sme_cntsd), (i64 3))), (RDSVLI_XI 1)>;
// Added side of the hunk: a shl of cntsd by any amount accepted by
// sme_cnts_shl_imm becomes a single RDSVLI_XI whose immediate is the matched
// value passed through the trunc_imm SDNodeXForm.
// NOTE(review): this presumably subsumes the removed shift-by-3 pattern --
// confirm against SelectRDSVLShiftImm, which is not visible here.
154+ def : Pat<(i64 (shl (int_aarch64_sme_cntsd), (sme_cnts_shl_imm i64:$imm))),
155+ (RDSVLI_XI (!cast<SDNodeXForm>("trunc_imm") $imm))>;
156156
157- // Generic pattern for cntsd (RDSVL #1 >> 3)
// The three patterns below all start from RDSVLI_XI with immediate 1 and
// apply UBFMXri (n, 63), i.e. a right shift by n: cntsd << 2 uses n = 1,
// cntsd << 1 uses n = 2, and plain cntsd uses n = 3.
157+ // cntsh, cntsw, cntsd
158+ def : Pat<(i64 (shl (int_aarch64_sme_cntsd), (i64 2))), (UBFMXri (RDSVLI_XI 1), 1, 63)>;
159+ def : Pat<(i64 (shl (int_aarch64_sme_cntsd), (i64 1))), (UBFMXri (RDSVLI_XI 1), 2, 63)>;
158160def : Pat<(i64 (int_aarch64_sme_cntsd)), (UBFMXri (RDSVLI_XI 1), 3, 63)>;
159161}
160162
0 commit comments