Skip to content

Commit 37487a8

Browse files
authored
1 parent 830b978 commit 37487a8

File tree

3 files changed

+15
-35
lines changed

3 files changed

+15
-35
lines changed

WORKSPACE.bazel

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,9 @@ workspace(name = "stablehlo")
1717

1818
load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
1919

20-
LLVM_COMMIT = "95c2d798148f12565dd4c9ddc753d196e47f230f"
20+
LLVM_COMMIT = "01d233ff403823389f8480897e41aea84ecbb3d3"
2121

22-
LLVM_SHA256 = "f11e5bbf17d50ff31addc9e1737d64e64a144fce928166de5878c72a1efcf9b4"
22+
LLVM_SHA256 = "283a1d9c251d5028ae78f7a659816588fedaa6a8ba5733bee7249724fb3ed2bc"
2323

2424
http_archive(
2525
name = "llvm-raw",

build_tools/llvm_version.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
95c2d798148f12565dd4c9ddc753d196e47f230f
1+
01d233ff403823389f8480897e41aea84ecbb3d3

stablehlo/reference/Tensor.cpp

Lines changed: 12 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -155,42 +155,22 @@ Element Tensor::get(const Index &index) const {
155155
// integer variants.
156156
if (isSupportedIntegerType(elementType)) {
157157
IntegerType intTy = cast<IntegerType>(elementType);
158-
159-
if (elementType.isSignlessInteger(2) || elementType.isSignlessInteger(4) ||
160-
elementType.isSignlessInteger(8)) {
161-
auto elementData = reinterpret_cast<const int8_t *>(elementPtr);
162-
return Element(elementType, APInt(intTy.getWidth(), *elementData,
163-
intTy.isSignedInteger()));
164-
} else if (elementType.isSignlessInteger(16)) {
165-
auto elementData = reinterpret_cast<const int16_t *>(elementPtr);
166-
return Element(elementType, APInt(intTy.getWidth(), *elementData,
167-
intTy.isSignedInteger()));
168-
} else if (elementType.isSignlessInteger(32)) {
169-
auto elementData = reinterpret_cast<const int32_t *>(elementPtr);
170-
return Element(elementType, APInt(intTy.getWidth(), *elementData,
171-
intTy.isSignedInteger()));
172-
} else if (elementType.isSignlessInteger(64)) {
173-
auto elementData = reinterpret_cast<const int64_t *>(elementPtr);
174-
return Element(elementType, APInt(intTy.getWidth(), *elementData,
175-
intTy.isSignedInteger()));
176-
} else if (elementType.isUnsignedInteger(2) ||
177-
elementType.isUnsignedInteger(4) ||
178-
elementType.isUnsignedInteger(8)) {
158+
const unsigned int bitwidth = intTy.getWidth();
159+
if (bitwidth == 2 || bitwidth == 4 || bitwidth == 8) {
179160
auto elementData = reinterpret_cast<const uint8_t *>(elementPtr);
180-
return Element(elementType, APInt(intTy.getWidth(), *elementData,
181-
intTy.isSignedInteger()));
182-
} else if (elementType.isUnsignedInteger(16)) {
161+
// Set implicitTrunc to ignore garbage bits on 2-bit and 4-bit types.
162+
const bool implicitTrunc = bitwidth == 2 || bitwidth == 4;
163+
return Element(elementType, APInt(bitwidth, *elementData,
164+
/*isSigned=*/false, implicitTrunc));
165+
} else if (bitwidth == 16) {
183166
auto elementData = reinterpret_cast<const uint16_t *>(elementPtr);
184-
return Element(elementType, APInt(intTy.getWidth(), *elementData,
185-
intTy.isSignedInteger()));
186-
} else if (elementType.isUnsignedInteger(32)) {
167+
return Element(elementType, APInt(bitwidth, *elementData));
168+
} else if (bitwidth == 32) {
187169
auto elementData = reinterpret_cast<const uint32_t *>(elementPtr);
188-
return Element(elementType, APInt(intTy.getWidth(), *elementData,
189-
intTy.isSignedInteger()));
190-
} else if (elementType.isUnsignedInteger(64)) {
170+
return Element(elementType, APInt(bitwidth, *elementData));
171+
} else if (bitwidth == 64) {
191172
auto elementData = reinterpret_cast<const uint64_t *>(elementPtr);
192-
return Element(elementType, APInt(intTy.getWidth(), *elementData,
193-
intTy.isSignedInteger()));
173+
return Element(elementType, APInt(bitwidth, *elementData));
194174
}
195175
}
196176

0 commit comments

Comments
 (0)