@@ -118,11 +118,21 @@ Element Tensor::get(const Index &index) const {
118118 getSizeInBytes (elementType) * flattenIndex (getShape (), index);
119119
120120 // Handle floating-point types.
121+ if (elementType.isFloat8E3M4 ()) {
122+ auto elementData = reinterpret_cast <const uint8_t *>(elementPtr);
123+ return Element (elementType, APFloat (llvm::APFloatBase::Float8E3M4 (),
124+ APInt (8 , *elementData)));
125+ }
121126 if (elementType.isFloat8E4M3B11FNUZ ()) {
122127 auto elementData = reinterpret_cast <const uint8_t *>(elementPtr);
123128 return Element (elementType, APFloat (llvm::APFloatBase::Float8E4M3B11FNUZ (),
124129 APInt (8 , *elementData)));
125130 }
131+ if (elementType.isFloat8E4M3 ()) {
132+ auto elementData = reinterpret_cast <const uint8_t *>(elementPtr);
133+ return Element (elementType, APFloat (llvm::APFloatBase::Float8E4M3 (),
134+ APInt (8 , *elementData)));
135+ }
126136 if (elementType.isFloat8E4M3FN ()) {
127137 auto elementData = reinterpret_cast <const uint8_t *>(elementPtr);
128138 return Element (elementType, APFloat (llvm::APFloatBase::Float8E4M3FN (),
@@ -252,7 +262,8 @@ void Tensor::set(const Index &index, const Element &element) {
252262 getSizeInBytes (elementType) * flattenIndex (getShape (), index);
253263
254264 // Handle floating-point types.
255- if (elementType.isFloat8E4M3B11FNUZ () || elementType.isFloat8E4M3FN () ||
265+ if (elementType.isFloat8E3M4 () || elementType.isFloat8E4M3B11FNUZ () ||
266+ elementType.isFloat8E4M3 () || elementType.isFloat8E4M3FN () ||
256267 elementType.isFloat8E4M3FNUZ () || elementType.isFloat8E5M2 () ||
257268 elementType.isFloat8E5M2FNUZ ()) {
258269 auto elementData = reinterpret_cast <uint8_t *>(elementPtr);
@@ -446,17 +457,18 @@ Tensor makeTensor(DenseElementsAttr attr) {
446457 auto elementType = type.getElementType ();
447458
448459 // Handle floating-point types.
449- if (elementType.isFloat8E4M3B11FNUZ () || elementType.isFloat8E4M3FN () ||
460+ if (elementType.isFloat8E3M4 () || elementType.isFloat8E4M3B11FNUZ () ||
461+ elementType.isFloat8E4M3 () || elementType.isFloat8E4M3FN () ||
450462 elementType.isFloat8E4M3FNUZ () || elementType.isFloat8E5M2 () ||
451463 elementType.isFloat8E5M2FNUZ ()) {
452464 auto floatValues = llvm::map_to_vector (
453465 attr.getValues <APFloat>(), [&](APFloat value) -> uint8_t {
454466 return value.bitcastToAPInt ().getZExtValue ();
455467 });
456468
457- // For f8E4M3B11FNUZ, f8E4M3FN, f8E4M3FNUZ, f8E5M2, and f8E5M2FNUZ
458- // floating-point types, we use uint8_t as their storage type because there
459- // are no builtin types for those.
469+ // For f8E3M4, f8E4M3, f8E4M3FN, f8E4M3FNUZ, f8E4M3B11FNUZ, f8E5M2, and
470+ // f8E5M2FNUZ floating-point types, we use uint8_t as their storage type
471+ // because there are no builtin types for those.
460472 return Tensor (type, HeapAsmResourceBlob::allocateAndCopyInferAlign<uint8_t >(
461473 floatValues));
462474 }
0 commit comments