diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
index cfb4af391b540..c169ab25b2106 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
@@ -3076,10 +3076,16 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
   case Intrinsic::arm_neon_aesd:
   case Intrinsic::arm_neon_aese:
   case Intrinsic::aarch64_crypto_aesd:
-  case Intrinsic::aarch64_crypto_aese: {
+  case Intrinsic::aarch64_crypto_aese:
+  case Intrinsic::aarch64_sve_aesd:
+  case Intrinsic::aarch64_sve_aese: {
     Value *DataArg = II->getArgOperand(0);
     Value *KeyArg = II->getArgOperand(1);
 
+    // Accept zero on either operand.
+    if (!match(KeyArg, m_ZeroInt()))
+      std::swap(KeyArg, DataArg);
+
     // Try to use the builtin XOR in AESE and AESD to eliminate a prior XOR
     Value *Data, *Key;
     if (match(KeyArg, m_ZeroInt()) &&
diff --git a/llvm/test/Transforms/InstCombine/AArch64/aes-intrinsics.ll b/llvm/test/Transforms/InstCombine/AArch64/aes-intrinsics.ll
index c6695f17b955b..8c69d0721b738 100644
--- a/llvm/test/Transforms/InstCombine/AArch64/aes-intrinsics.ll
+++ b/llvm/test/Transforms/InstCombine/AArch64/aes-intrinsics.ll
@@ -13,6 +13,17 @@ define <16 x i8> @combineXorAeseZeroARM64(<16 x i8> %data, <16 x i8> %key) {
   ret <16 x i8> %data.aes
 }
 
+define <16 x i8> @combineXorAeseZeroLhsARM64(<16 x i8> %data, <16 x i8> %key) {
+; CHECK-LABEL: define <16 x i8> @combineXorAeseZeroLhsARM64(
+; CHECK-SAME: <16 x i8> [[DATA:%.*]], <16 x i8> [[KEY:%.*]]) {
+; CHECK-NEXT:    [[DATA_AES:%.*]] = tail call <16 x i8> @llvm.aarch64.crypto.aese(<16 x i8> [[DATA]], <16 x i8> [[KEY]])
+; CHECK-NEXT:    ret <16 x i8> [[DATA_AES]]
+;
+  %data.xor = xor <16 x i8> %data, %key
+  %data.aes = tail call <16 x i8> @llvm.aarch64.crypto.aese(<16 x i8> zeroinitializer, <16 x i8> %data.xor)
+  ret <16 x i8> %data.aes
+}
+
 define <16 x i8> @combineXorAeseNonZeroARM64(<16 x i8> %data, <16 x i8> %key) {
 ; CHECK-LABEL: define <16 x i8> @combineXorAeseNonZeroARM64(
 ; CHECK-SAME: <16 x i8> [[DATA:%.*]], <16 x i8> [[KEY:%.*]]) {
@@ -36,6 +47,17 @@ define <16 x i8> @combineXorAesdZeroARM64(<16 x i8> %data, <16 x i8> %key) {
   ret <16 x i8> %data.aes
 }
 
+define <16 x i8> @combineXorAesdZeroLhsARM64(<16 x i8> %data, <16 x i8> %key) {
+; CHECK-LABEL: define <16 x i8> @combineXorAesdZeroLhsARM64(
+; CHECK-SAME: <16 x i8> [[DATA:%.*]], <16 x i8> [[KEY:%.*]]) {
+; CHECK-NEXT:    [[DATA_AES:%.*]] = tail call <16 x i8> @llvm.aarch64.crypto.aesd(<16 x i8> [[DATA]], <16 x i8> [[KEY]])
+; CHECK-NEXT:    ret <16 x i8> [[DATA_AES]]
+;
+  %data.xor = xor <16 x i8> %data, %key
+  %data.aes = tail call <16 x i8> @llvm.aarch64.crypto.aesd(<16 x i8> zeroinitializer, <16 x i8> %data.xor)
+  ret <16 x i8> %data.aes
+}
+
 define <16 x i8> @combineXorAesdNonZeroARM64(<16 x i8> %data, <16 x i8> %key) {
 ; CHECK-LABEL: define <16 x i8> @combineXorAesdNonZeroARM64(
 ; CHECK-SAME: <16 x i8> [[DATA:%.*]], <16 x i8> [[KEY:%.*]]) {
@@ -51,3 +73,51 @@ define <16 x i8> @combineXorAesdNonZeroARM64(<16 x i8> %data, <16 x i8> %key) {
 
 declare <16 x i8> @llvm.aarch64.crypto.aese(<16 x i8>, <16 x i8>) #0
 declare <16 x i8> @llvm.aarch64.crypto.aesd(<16 x i8>, <16 x i8>) #0
+; SVE
+
+define <vscale x 16 x i8> @combineXorAeseZeroLhsSVE(<vscale x 16 x i8> %data, <vscale x 16 x i8> %key) {
+; CHECK-LABEL: define <vscale x 16 x i8> @combineXorAeseZeroLhsSVE(
+; CHECK-SAME: <vscale x 16 x i8> [[DATA:%.*]], <vscale x 16 x i8> [[KEY:%.*]]) {
+; CHECK-NEXT:    [[DATA_AES:%.*]] = tail call <vscale x 16 x i8> @llvm.aarch64.sve.aese(<vscale x 16 x i8> [[DATA]], <vscale x 16 x i8> [[KEY]])
+; CHECK-NEXT:    ret <vscale x 16 x i8> [[DATA_AES]]
+;
+  %data.xor = xor <vscale x 16 x i8> %data, %key
+  %data.aes = tail call <vscale x 16 x i8> @llvm.aarch64.sve.aese(<vscale x 16 x i8> zeroinitializer, <vscale x 16 x i8> %data.xor)
+  ret <vscale x 16 x i8> %data.aes
+}
+
+define <vscale x 16 x i8> @combineXorAeseZeroRhsSVE(<vscale x 16 x i8> %data, <vscale x 16 x i8> %key) {
+; CHECK-LABEL: define <vscale x 16 x i8> @combineXorAeseZeroRhsSVE(
+; CHECK-SAME: <vscale x 16 x i8> [[DATA:%.*]], <vscale x 16 x i8> [[KEY:%.*]]) {
+; CHECK-NEXT:    [[DATA_AES:%.*]] = tail call <vscale x 16 x i8> @llvm.aarch64.sve.aese(<vscale x 16 x i8> [[DATA]], <vscale x 16 x i8> [[KEY]])
+; CHECK-NEXT:    ret <vscale x 16 x i8> [[DATA_AES]]
+;
+  %data.xor = xor <vscale x 16 x i8> %data, %key
+  %data.aes = tail call <vscale x 16 x i8> @llvm.aarch64.sve.aese(<vscale x 16 x i8> %data.xor, <vscale x 16 x i8> zeroinitializer)
+  ret <vscale x 16 x i8> %data.aes
+}
+
+define <vscale x 16 x i8> @combineXorAesdZeroLhsSVE(<vscale x 16 x i8> %data, <vscale x 16 x i8> %key) {
+; CHECK-LABEL: define <vscale x 16 x i8> @combineXorAesdZeroLhsSVE(
+; CHECK-SAME: <vscale x 16 x i8> [[DATA:%.*]], <vscale x 16 x i8> [[KEY:%.*]]) {
+; CHECK-NEXT:    [[DATA_AES:%.*]] = tail call <vscale x 16 x i8> @llvm.aarch64.sve.aesd(<vscale x 16 x i8> [[DATA]], <vscale x 16 x i8> [[KEY]])
+; CHECK-NEXT:    ret <vscale x 16 x i8> [[DATA_AES]]
+;
+  %data.xor = xor <vscale x 16 x i8> %data, %key
+  %data.aes = tail call <vscale x 16 x i8> @llvm.aarch64.sve.aesd(<vscale x 16 x i8> zeroinitializer, <vscale x 16 x i8> %data.xor)
+  ret <vscale x 16 x i8> %data.aes
+}
+
+define <vscale x 16 x i8> @combineXorAesdZeroRhsSVE(<vscale x 16 x i8> %data, <vscale x 16 x i8> %key) {
+; CHECK-LABEL: define <vscale x 16 x i8> @combineXorAesdZeroRhsSVE(
+; CHECK-SAME: <vscale x 16 x i8> [[DATA:%.*]], <vscale x 16 x i8> [[KEY:%.*]]) {
+; CHECK-NEXT:    [[DATA_AES:%.*]] = tail call <vscale x 16 x i8> @llvm.aarch64.sve.aesd(<vscale x 16 x i8> [[DATA]], <vscale x 16 x i8> [[KEY]])
+; CHECK-NEXT:    ret <vscale x 16 x i8> [[DATA_AES]]
+;
+  %data.xor = xor <vscale x 16 x i8> %data, %key
+  %data.aes = tail call <vscale x 16 x i8> @llvm.aarch64.sve.aesd(<vscale x 16 x i8> %data.xor, <vscale x 16 x i8> zeroinitializer)
+  ret <vscale x 16 x i8> %data.aes
+}
+
+declare <vscale x 16 x i8> @llvm.aarch64.sve.aese(<vscale x 16 x i8>, <vscale x 16 x i8>) #0
+declare <vscale x 16 x i8> @llvm.aarch64.sve.aesd(<vscale x 16 x i8>, <vscale x 16 x i8>) #0
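
Note: for context, the combine as it reads after this patch is sketched below. The hunk above truncates mid-condition, so the `m_Xor` match, the `replaceOperand` calls, and the trailing `break` are assumptions about the unmodified surrounding code, not part of this patch. The swap is sound because AESE/AESD XOR their two operands first (AddRoundKey), making the intrinsics commutative in their arguments; a zero found in either position can therefore be canonicalized into the key position before the existing fold runs.

    case Intrinsic::aarch64_sve_aese: { // ...and the other AES cases above
      Value *DataArg = II->getArgOperand(0);
      Value *KeyArg = II->getArgOperand(1);

      // Canonicalize a zero operand into the key position so a single
      // match below handles both operand orders.
      if (!match(KeyArg, m_ZeroInt()))
        std::swap(KeyArg, DataArg);

      // Fold aes*(0, xor(a, b)) --> aes*(a, b): the instruction's own
      // leading XOR subsumes the explicit one.
      Value *Data, *Key;
      if (match(KeyArg, m_ZeroInt()) &&
          match(DataArg, m_Xor(m_Value(Data), m_Value(Key)))) {
        replaceOperand(*II, 0, Data); // assumed continuation, not in hunk
        replaceOperand(*II, 1, Key);  // assumed continuation, not in hunk
        return II;
      }
      break;
    }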