@@ -39,6 +39,115 @@ inline int8_t upc(int8_t h) {
     return h | (-((h & (1 << 3)) >> 3) & (-8));
 }

+// OpenVINO core semantics for e4m3: mag = LUT[bits & 0x7F], sign by bit7.
+// Implement with AVX2 gather from 128-entry float LUT, then cvtps_ph.
+inline __m256i f8e4m3tof16(__m128i vf8) {
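+    // 128 magnitudes indexed by the low 7 bits of the f8 value: 8 denormal steps
+    // (multiples of 2^-9), then the normal exponent bins; index 0x7F encodes NaN.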
+    alignas(32) static constexpr float k_f8e4m3_lut[128] = {
+        0.0f, 0.001953125f, 0.00390625f, 0.005859375f,
+        0.0078125f, 0.009765625f, 0.01171875f, 0.013671875f,
+        0.015625f, 0.017578125f, 0.01953125f, 0.021484375f,
+        0.0234375f, 0.025390625f, 0.02734375f, 0.029296875f,
+        0.03125f, 0.03515625f, 0.0390625f, 0.04296875f,
+        0.046875f, 0.05078125f, 0.0546875f, 0.05859375f,
+        0.0625f, 0.0703125f, 0.078125f, 0.0859375f,
+        0.09375f, 0.1015625f, 0.109375f, 0.1171875f,
+        0.125f, 0.140625f, 0.15625f, 0.171875f,
+        0.1875f, 0.203125f, 0.21875f, 0.234375f,
+        0.25f, 0.28125f, 0.3125f, 0.34375f,
+        0.375f, 0.40625f, 0.4375f, 0.46875f,
+        0.5f, 0.5625f, 0.625f, 0.6875f,
+        0.75f, 0.8125f, 0.875f, 0.9375f,
+        1.0f, 1.125f, 1.25f, 1.375f,
+        1.5f, 1.625f, 1.75f, 1.875f,
+        2.0f, 2.25f, 2.5f, 2.75f,
+        3.0f, 3.25f, 3.5f, 3.75f,
+        4.0f, 4.5f, 5.0f, 5.5f,
+        6.0f, 6.5f, 7.0f, 7.5f,
+        8.0f, 9.0f, 10.0f, 11.0f,
+        12.0f, 13.0f, 14.0f, 15.0f,
+        16.0f, 18.0f, 20.0f, 22.0f,
+        24.0f, 26.0f, 28.0f, 30.0f,
+        32.0f, 36.0f, 40.0f, 44.0f,
+        48.0f, 52.0f, 56.0f, 60.0f,
+        64.0f, 72.0f, 80.0f, 88.0f,
+        96.0f, 104.0f, 112.0f, 120.0f,
+        128.0f, 144.0f, 160.0f, 176.0f,
+        192.0f, 208.0f, 224.0f, 240.0f,
+        256.0f, 288.0f, 320.0f, 352.0f,
+        384.0f, 416.0f, 448.0f, std::numeric_limits<float>::quiet_NaN(),
+    };
+
+    __m256i u16 = _mm256_cvtepu8_epi16(vf8);
+
+    const __m256i sign_m = _mm256_set1_epi16(0x80);
+    const __m256i idx_m = _mm256_set1_epi16(0x7F);
+
+    __m256i sign16 = _mm256_and_si256(u16, sign_m);
+    __m256i idx16 = _mm256_and_si256(u16, idx_m);
+
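+    // Each 16-bit lane of sign16 is either 0x0080 or 0, so a signed compare
+    // against zero gives an all-ones mask exactly for negative f8 inputs.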
+    __m256i sign_nonzero = _mm256_cmpgt_epi16(sign16, _mm256_setzero_si256());
+
+    __m128i idx16_lo = _mm256_castsi256_si128(idx16);
+    __m128i idx16_hi = _mm256_extracti128_si256(idx16, 1);
+    __m256i idx32_lo = _mm256_cvtepu16_epi32(idx16_lo);
+    __m256i idx32_hi = _mm256_cvtepu16_epi32(idx16_hi);
+
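+    // The gather requires 32-bit indices (hence the widening above); a scale of 4
+    // turns each index into a byte offset into the float LUT.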
+    __m256 mag_lo = _mm256_i32gather_ps(k_f8e4m3_lut, idx32_lo, 4);
+    __m256 mag_hi = _mm256_i32gather_ps(k_f8e4m3_lut, idx32_hi, 4);
+
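+    // For negative inputs the mask lanes are all-ones; sign-extending to 32 bit and
+    // shifting left by 31 leaves only the IEEE sign bit, which is XORed into the magnitude.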
+    __m256i sign32_lo = _mm256_slli_epi32(_mm256_cvtepi16_epi32(_mm256_castsi256_si128(sign_nonzero)), 31);
+    __m256i sign32_hi = _mm256_slli_epi32(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(sign_nonzero, 1)), 31);
+
+    __m256 v_lo = _mm256_xor_ps(mag_lo, _mm256_castsi256_ps(sign32_lo));
+    __m256 v_hi = _mm256_xor_ps(mag_hi, _mm256_castsi256_ps(sign32_hi));
+
+    __m128i h0 = _mm256_cvtps_ph(v_lo, _MM_FROUND_TO_NEAREST_INT);
+    __m128i h1 = _mm256_cvtps_ph(v_hi, _MM_FROUND_TO_NEAREST_INT);
+    return _mm256_set_m128i(h1, h0);
+}
+
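+// f8e5m2 has the same 1-5-2 sign/exponent/mantissa layout as the top byte of an
+// IEEE f16 (1-5-10), so shifting each byte into the high byte of a 16-bit lane is
+// an exact conversion, including the Inf and NaN encodings.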
+inline __m256i f8e5m2tof16(__m128i vf8) {
+    const __m256i b16 = _mm256_cvtepu8_epi16(vf8);
+    return _mm256_slli_epi16(b16, 8);
+}
+
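+// f8e8m0 is an unsigned, exponent-only format (bias 127, 0xFF == NaN): shift the
+// byte straight into the f32 exponent field to get 2^(e - 127), then narrow to f16.
+// 0x00 is mapped to zero since 2^-127 underflows f16 anyway.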
+inline __m256i f8e8m0tof16(__m128i vf8) {
+    __m256i u16 = _mm256_cvtepu8_epi16(vf8);
+
+    __m128i u16_lo = _mm256_castsi256_si128(u16);
+    __m128i u16_hi = _mm256_extracti128_si256(u16, 1);
+    __m256i u32_lo = _mm256_cvtepu16_epi32(u16_lo);
+    __m256i u32_hi = _mm256_cvtepu16_epi32(u16_hi);
+
+    const __m256i ff = _mm256_set1_epi32(0xFF);
+    const __m256i zz = _mm256_setzero_si256();
+
+    __m256i is_ff_lo = _mm256_cmpeq_epi32(u32_lo, ff);
+    __m256i is_ff_hi = _mm256_cmpeq_epi32(u32_hi, ff);
+    __m256i is_00_lo = _mm256_cmpeq_epi32(u32_lo, zz);
+    __m256i is_00_hi = _mm256_cmpeq_epi32(u32_hi, zz);
+
+    __m256i fbits_lo = _mm256_slli_epi32(u32_lo, 23);
+    __m256i fbits_hi = _mm256_slli_epi32(u32_hi, 23);
+
+    const __m256i qnan_bits = _mm256_set1_epi32(0x7FC00000);
+
+    const __m256i zero_fbits = _mm256_setzero_si256();
+
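+    // Patch the two special encodings: 0x00 -> +0.0f, 0xFF -> quiet NaN.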
+    fbits_lo = _mm256_blendv_epi8(fbits_lo, zero_fbits, is_00_lo);
+    fbits_hi = _mm256_blendv_epi8(fbits_hi, zero_fbits, is_00_hi);
+
+    fbits_lo = _mm256_blendv_epi8(fbits_lo, qnan_bits, is_ff_lo);
+    fbits_hi = _mm256_blendv_epi8(fbits_hi, qnan_bits, is_ff_hi);
+
+    __m256 f_lo = _mm256_castsi256_ps(fbits_lo);
+    __m256 f_hi = _mm256_castsi256_ps(fbits_hi);
+
+    __m128i h0 = _mm256_cvtps_ph(f_lo, _MM_FROUND_TO_NEAREST_INT);
+    __m128i h1 = _mm256_cvtps_ph(f_hi, _MM_FROUND_TO_NEAREST_INT);
+    return _mm256_set_m128i(h1, h0);
+}
+
 inline int32_t pack_4bit_avx2_reduction(__m256i ymm) {
     __m256i mask = _mm256_set1_epi32(0xF);
     ymm = _mm256_and_si256(ymm, mask);
@@ -1632,3 +1741,167 @@ void ov::npuw::util::XARCH::transpose_f32(const float* src, float* dst, size_t r
     OPENVINO_THROW("AVX2 support is necessary but it's not enabled!");
 #endif
 }
+
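+// Dequantize an f8 (e4m3 / e5m2 / e8m0) tensor into f16: each group of
+// `elemsPerScale` consecutive elements is multiplied by its own f32 scale factor.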
+void ov::npuw::util::XARCH::unpack_f8f16_scale(const ov::SoPtr<ov::ITensor>& from,
+                                               const ov::SoPtr<ov::ITensor>& scale,
+                                               const ov::SoPtr<ov::ITensor>& to,
+                                               const ov::npuw::util::UnpackOptions& unpack_options) {
+    NPUW_ASSERT(from->is_continuous());
+    NPUW_ASSERT(scale->is_continuous());
+    NPUW_ASSERT(to->is_continuous());
+
+    const auto from_shape = from->get_shape();
+    const auto scale_shape = scale->get_shape();
+
+    NPUW_ASSERT(from->get_size() == to->get_size());
+    NPUW_ASSERT(scale_shape.size() >= 2);
+    NPUW_ASSERT(from_shape[0] == scale_shape[0]);
+    NPUW_ASSERT(scale_shape[1] == 1);
+
+    const auto ftype = from->get_element_type();
+    NPUW_ASSERT(ftype == ov::element::f8e4m3 || ftype == ov::element::f8e5m2 || ftype == ov::element::f8e8m0);
+    NPUW_ASSERT(scale->get_element_type() == ov::element::f32);
+    NPUW_ASSERT(to->get_element_type() == ov::element::f16);
+
+    const size_t total = from->get_size();    // total number of f8 elements
+    const size_t stotal = scale->get_size();  // number of scale factors
+    NPUW_ASSERT(total % stotal == 0);
+    const size_t elemsPerScale = total / stotal;  // elements governed by one scale
+    NPUW_ASSERT(elemsPerScale > 0);
+
+#if defined(HAVE_AVX2)
+    const uint8_t* src = static_cast<uint8_t*>(from->data());
+    const float* scl = static_cast<float*>(scale->data());
+    uint16_t* dst = static_cast<uint16_t*>(to->data());
+
+    const size_t VEC = 16;  // vector width (16 x f8 -> 16 x f16)
+    // Vector convert helper: load 16 f8 -> 16 f16, apply uniform scale
+    auto convert_block = [&](const uint8_t* bsrc, uint16_t* bdst, float scale_val) {
+        __m128i vf8 = _mm_loadu_si128(reinterpret_cast<const __m128i*>(bsrc));
+        __m256i vf16_bits;
+        switch (ftype) {
+        case ov::element::f8e4m3:
+            vf16_bits = f8e4m3tof16(vf8);
+            break;
+        case ov::element::f8e5m2:
+            vf16_bits = f8e5m2tof16(vf8);
+            break;
+        case ov::element::f8e8m0:
+            vf16_bits = f8e8m0tof16(vf8);
+            break;
+        default:
+            NPUW_ASSERT(false);
+            return;
+        }
+
+        // Split into two 128-bit halves (each holds 8 f16)
+        __m128i h_lo = _mm256_castsi256_si128(vf16_bits);
+        __m128i h_hi = _mm256_extracti128_si256(vf16_bits, 1);
+
+        // Convert to float32
+        __m256 f_lo = _mm256_cvtph_ps(h_lo);
+        __m256 f_hi = _mm256_cvtph_ps(h_hi);
+
+        // Apply scale
+        __m256 svec = _mm256_set1_ps(scale_val);
+        f_lo = _mm256_mul_ps(f_lo, svec);
+        f_hi = _mm256_mul_ps(f_hi, svec);
+
+        // Back to f16
+        __m128i out_lo = _mm256_cvtps_ph(f_lo, _MM_FROUND_TO_NEAREST_INT);
+        __m128i out_hi = _mm256_cvtps_ph(f_hi, _MM_FROUND_TO_NEAREST_INT);
+
+        _mm_storeu_si128(reinterpret_cast<__m128i*>(bdst), out_lo);
+        _mm_storeu_si128(reinterpret_cast<__m128i*>(bdst + 8), out_hi);
+    };
+
+    // Work partitioning over scale dimension (similar pattern to other unpack_*_scale functions)
+    size_t stride = 1;
+    if (unpack_options.nPartitions) {
+        if (unpack_options.bStrictPartitioning) {
+            stride = (stotal + unpack_options.nPartitions - 1) / unpack_options.nPartitions;
+        } else {
+            // Heuristic: ensure minimum intrinsic workload per thread.
+            // Require at least 2048 vector blocks per thread (if possible).
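+            // E.g. elemsPerScale == 2048 gives 128 vector blocks per scale row, so at
+            // least 2048 / 128 == 16 scale rows go to each partition.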
+            size_t vecBlocksPerScale = elemsPerScale / VEC;
+            if (vecBlocksPerScale == 0)
+                vecBlocksPerScale = 1;
+            size_t minScaleStride = 2048 / vecBlocksPerScale;
+            if (minScaleStride == 0)
+                minScaleStride = 1;
+            size_t minPartitions = stotal / minScaleStride;
+            if (minPartitions == 0)
+                minPartitions = 1;
+            minPartitions = std::min(minPartitions, unpack_options.nPartitions);
+            stride = stotal / minPartitions;
+            if (stride == 0)
+                stride = 1;
+        }
+    }
+    const size_t numWork = (stotal + stride - 1) / stride;
+
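+    // Each work item covers a contiguous run of `stride` scale groups.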
+    auto unpack_body = [&](size_t workIndex) {
+        size_t start = workIndex * stride;
+        size_t end = std::min(stotal, start + stride);
+        for (size_t s = start; s < end; ++s) {
+            const float scale_val = scl[s];
+            const uint8_t* src_scale_base = src + s * elemsPerScale;
+            uint16_t* dst_scale_base = dst + s * elemsPerScale;
+
+            size_t vecBlocks = elemsPerScale / VEC;
+            size_t tail = elemsPerScale % VEC;
+
+            // Vector path
+            for (size_t b = 0; b < vecBlocks; ++b) {
+                convert_block(src_scale_base + b * VEC, dst_scale_base + b * VEC, scale_val);
+            }
+
+            // Tail (scalar fallback)
+            if (tail) {
+                size_t offset = vecBlocks * VEC;
+                for (size_t t = 0; t < tail; ++t) {
+                    uint8_t v = src_scale_base[offset + t];
+                    uint16_t h;
+                    // Direct bit remapping for the tail; e4m3/e5m2 specials (NaN/Inf, denormals) are approximated.
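+                    // e4m3: rebias the exponent (7 -> 15) and widen the 3-bit mantissa to 10 bits;
+                    // e5m2: the exponent bias already matches f16, so only the mantissa is widened.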
+                    if (ftype == ov::element::f8e4m3) {
+                        uint8_t sign = (v & 0x80) >> 7;
+                        uint8_t exp = (v & 0x78) >> 3;
+                        uint8_t man = (v & 0x07);
+                        int16_t exp16 = exp + 8;
+                        h = static_cast<uint16_t>((sign << 15) | (exp16 << 10) | (man << 7));
+                    } else if (ftype == ov::element::f8e5m2) {
+                        uint8_t sign = (v & 0x80) >> 7;
+                        uint8_t exp = (v & 0x7C) >> 2;
+                        uint8_t man = (v & 0x03);
+                        h = static_cast<uint16_t>((sign << 15) | (exp << 10) | (man << 8));
+                    } else {  // f8e8m0: unsigned, full 8-bit exponent (matches the vector path); values below 2^-14 flush to zero
+                        if (v == 0xFF) {
+                            h = 0x7E00;  // NaN
+                        } else {
+                            int16_t exp16 = static_cast<int16_t>(v) - 112;  // rebias 127 -> 15
+                            h = (exp16 <= 0) ? 0 : (exp16 >= 0x1F) ? 0x7C00 : static_cast<uint16_t>(exp16 << 10);
+                        }
+                    }
+                    // Convert single f16 -> f32 -> scale -> f16
+                    __m128i hvec = _mm_cvtsi32_si128(h);
+                    __m256 f32v = _mm256_cvtph_ps(hvec);
+                    float fval = _mm_cvtss_f32(_mm256_castps256_ps128(f32v)) * scale_val;
+                    __m256 scaled = _mm256_set1_ps(fval);
+                    __m128i out_h = _mm256_cvtps_ph(scaled, _MM_FROUND_TO_NEAREST_INT);
+                    dst_scale_base[offset + t] = static_cast<uint16_t>(_mm_cvtsi128_si32(out_h));
+                }
+            }
+        }
+    };
+
+    if (unpack_options.bUseOvParallelFor) {
+        ov::parallel_for(numWork, [&](size_t wi) {
+            unpack_body(wi);
+        });
+    } else {
+        for (size_t wi = 0; wi < numWork; ++wi) {
+            unpack_body(wi);
+        }
+    }
+#else
+    OPENVINO_THROW("AVX2 support is necessary but it's not enabled!");
+#endif
+}