diff --git a/llvm/include/llvm/ADT/APInt.h b/llvm/include/llvm/ADT/APInt.h index 295506393a1c4..4d60ad4459ad6 100644 --- a/llvm/include/llvm/ADT/APInt.h +++ b/llvm/include/llvm/ADT/APInt.h @@ -2294,6 +2294,12 @@ LLVM_ABI APInt mulhs(const APInt &C1, const APInt &C2); /// Returns the high N bits of the multiplication result. LLVM_ABI APInt mulhu(const APInt &C1, const APInt &C2); +/// Performs (2*N)-bit multiplication on sign-extended operands. +LLVM_ABI APInt mulsExtended(const APInt &C1, const APInt &C2); + +/// Performs (2*N)-bit multiplication on zero-extended operands. +LLVM_ABI APInt muluExtended(const APInt &C1, const APInt &C2); + /// Compute X^N for N>=0. /// 0^0 is supported and returns 1. LLVM_ABI APInt pow(const APInt &X, int64_t N); diff --git a/llvm/lib/Support/APInt.cpp b/llvm/lib/Support/APInt.cpp index 954af7fff92a8..1547f48bc7ac0 100644 --- a/llvm/lib/Support/APInt.cpp +++ b/llvm/lib/Support/APInt.cpp @@ -3136,6 +3136,22 @@ APInt APIntOps::mulhu(const APInt &C1, const APInt &C2) { return (C1Ext * C2Ext).extractBits(C1.getBitWidth(), C1.getBitWidth()); } +APInt APIntOps::mulsExtended(const APInt &C1, const APInt &C2) { + assert(C1.getBitWidth() == C2.getBitWidth() && "Unequal bitwidths"); + unsigned FullWidth = C1.getBitWidth() * 2; + APInt C1Ext = C1.sext(FullWidth); + APInt C2Ext = C2.sext(FullWidth); + return C1Ext * C2Ext; +} + +APInt APIntOps::muluExtended(const APInt &C1, const APInt &C2) { + assert(C1.getBitWidth() == C2.getBitWidth() && "Unequal bitwidths"); + unsigned FullWidth = C1.getBitWidth() * 2; + APInt C1Ext = C1.zext(FullWidth); + APInt C2Ext = C2.zext(FullWidth); + return C1Ext * C2Ext; +} + APInt APIntOps::pow(const APInt &X, int64_t N) { assert(N >= 0 && "negative exponents not supported."); APInt Acc = APInt(X.getBitWidth(), 1); diff --git a/llvm/unittests/ADT/APIntTest.cpp b/llvm/unittests/ADT/APIntTest.cpp index 4741c7bcc140f..acc6a09acbf4d 100644 --- a/llvm/unittests/ADT/APIntTest.cpp +++ b/llvm/unittests/ADT/APIntTest.cpp @@ -3103,6 +3103,53 @@ TEST(APIntOpsTest, Mulh) { EXPECT_EQ(APInt(128, "FFEB498812C66C68D4552DB89B8EBF8F", 16), i128Res); } +TEST(APIntOpsTest, muli) { + APInt u32a(32, 0x0001'E235); + APInt u32b(32, 0xF623'55AD); + EXPECT_EQ(0x0001'CFA1'7CA0'76D1, APIntOps::muluExtended(u32a, u32b)); + + APInt u64a(64, 0x1234'5678'90AB'CDEF); + APInt u64b(64, 0xFEDC'BA09'8765'4321); + EXPECT_EQ(APInt(128, "121FA000A3723A57C24A442FE55618CF", 16), + APIntOps::muluExtended(u64a, u64b)); + + APInt u128a(128, "1234567890ABCDEF1234567890ABCDEF", 16); + APInt u128b(128, "FEDCBA0987654321FEDCBA0987654321", 16); + EXPECT_EQ( + APInt(256, + "121FA000A3723A57E68984312C3A8D7E96B428606E1E6BF5C24A442FE55618CF", + 16), + APIntOps::muluExtended(u128a, u128b)); + + APInt s32a(32, 0x1234'5678); + APInt s32b(32, 0x10AB'CDEF); + APInt s32c(32, 0xFEDC'BA09); + EXPECT_EQ(0x012F'7D02'2A42'D208, APIntOps::mulsExtended(s32a, s32b)); + EXPECT_EQ(0xFFEB'4988'09CA'3A38, APIntOps::mulsExtended(s32a, s32c)); + + APInt s64a(64, 0x1234'5678'90AB'CDEF); + APInt s64b(64, 0x1234'5678'90FE'DCBA); + APInt s64c(64, 0xFEDC'BA09'8765'4321); + EXPECT_EQ(APInt(128, "014B66DC328E10C1FB99704184EF03A6", 16), + APIntOps::mulsExtended(s64a, s64b)); + EXPECT_EQ(APInt(128, "FFEB498812C66C68C24A442FE55618CF", 16), + APIntOps::mulsExtended(s64a, s64c)); + + APInt s128a(128, "1234567890ABCDEF1234567890ABCDEF", 16); + APInt s128b(128, "1234567890FEDCBA1234567890FEDCBA", 16); + APInt s128c(128, "FEDCBA0987654321FEDCBA0987654321", 16); + EXPECT_EQ( + APInt(256, + "014B66DC328E10C1FE303DF9EA0B2529F87E475F3C6C180DFB99704184EF03A6", + 16), + APIntOps::mulsExtended(s128a, s128b)); + EXPECT_EQ( + APInt(256, + "FFEB498812C66C68D4552DB89B8EBF8F96B428606E1E6BF5C24A442FE55618CF", + 16), + APIntOps::mulsExtended(s128a, s128c)); +} + TEST(APIntTest, RoundingUDiv) { for (uint64_t Ai = 1; Ai <= 255; Ai++) { APInt A(8, Ai);