@@ -155,42 +155,22 @@ Element Tensor::get(const Index &index) const {
155155 // integer variants.
156156 if (isSupportedIntegerType (elementType)) {
157157 IntegerType intTy = cast<IntegerType>(elementType);
158-
159- if (elementType.isSignlessInteger (2 ) || elementType.isSignlessInteger (4 ) ||
160- elementType.isSignlessInteger (8 )) {
161- auto elementData = reinterpret_cast <const int8_t *>(elementPtr);
162- return Element (elementType, APInt (intTy.getWidth (), *elementData,
163- intTy.isSignedInteger ()));
164- } else if (elementType.isSignlessInteger (16 )) {
165- auto elementData = reinterpret_cast <const int16_t *>(elementPtr);
166- return Element (elementType, APInt (intTy.getWidth (), *elementData,
167- intTy.isSignedInteger ()));
168- } else if (elementType.isSignlessInteger (32 )) {
169- auto elementData = reinterpret_cast <const int32_t *>(elementPtr);
170- return Element (elementType, APInt (intTy.getWidth (), *elementData,
171- intTy.isSignedInteger ()));
172- } else if (elementType.isSignlessInteger (64 )) {
173- auto elementData = reinterpret_cast <const int64_t *>(elementPtr);
174- return Element (elementType, APInt (intTy.getWidth (), *elementData,
175- intTy.isSignedInteger ()));
176- } else if (elementType.isUnsignedInteger (2 ) ||
177- elementType.isUnsignedInteger (4 ) ||
178- elementType.isUnsignedInteger (8 )) {
158+ const unsigned int bitwidth = intTy.getWidth ();
159+ if (bitwidth == 2 || bitwidth == 4 || bitwidth == 8 ) {
179160 auto elementData = reinterpret_cast <const uint8_t *>(elementPtr);
180- return Element (elementType, APInt (intTy.getWidth (), *elementData,
181- intTy.isSignedInteger ()));
182- } else if (elementType.isUnsignedInteger (16 )) {
161+ // Set implicitTrunc to ignore garbage bits on 2-bit and 4-bit types.
162+ const bool implicitTrunc = bitwidth == 2 || bitwidth == 4 ;
163+ return Element (elementType, APInt (bitwidth, *elementData,
164+ /* isSigned=*/ false , implicitTrunc));
165+ } else if (bitwidth == 16 ) {
183166 auto elementData = reinterpret_cast <const uint16_t *>(elementPtr);
184- return Element (elementType, APInt (intTy.getWidth (), *elementData,
185- intTy.isSignedInteger ()));
186- } else if (elementType.isUnsignedInteger (32 )) {
167+ return Element (elementType, APInt (bitwidth, *elementData));
168+ } else if (bitwidth == 32 ) {
187169 auto elementData = reinterpret_cast <const uint32_t *>(elementPtr);
188- return Element (elementType, APInt (intTy.getWidth (), *elementData,
189- intTy.isSignedInteger ()));
190- } else if (elementType.isUnsignedInteger (64 )) {
170+ return Element (elementType, APInt (bitwidth, *elementData));
171+ } else if (bitwidth == 64 ) {
191172 auto elementData = reinterpret_cast <const uint64_t *>(elementPtr);
192- return Element (elementType, APInt (intTy.getWidth (), *elementData,
193- intTy.isSignedInteger ()));
173+ return Element (elementType, APInt (bitwidth, *elementData));
194174 }
195175 }
196176
0 commit comments