From 80a056d57148cf14cdfe1c1f4fef5eca22d1e4fa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?St=C3=A9phane=20du=20Hamel?= Date: Wed, 8 Jan 2025 21:04:14 +0100 Subject: [PATCH] refactor fp8 + add e3m4 (fn) --- model.cpp | 151 ++++++++++++++++++++++++++++-------------------------- model.h | 5 +- 2 files changed, 83 insertions(+), 73 deletions(-) diff --git a/model.cpp b/model.cpp index dcbaae5bc..fae38df9b 100644 --- a/model.cpp +++ b/model.cpp @@ -609,87 +609,72 @@ float bf16_to_f32(uint16_t bfloat16) { return *reinterpret_cast(&val_bits); } -uint16_t f8_e4m3_to_f16(uint8_t f8) { - // do we need to support uz? - - const uint32_t exponent_bias = 7; - if (f8 == 0xff) { - return ggml_fp32_to_fp16(-NAN); - } else if (f8 == 0x7f) { - return ggml_fp32_to_fp16(NAN); +uint16_t f8_e3m4_to_f16(uint8_t fp8) { + if ((fp8 & 0x7F) == 0 || (fp8 & 0x7F) == 0x7F) { + // +/- 0 or NaN + return static_cast(fp8) << 8; } + const uint8_t exponent_bias = 0x3; // 2^(3-1)-1 + const uint8_t f16_bias = 0xF; // 2^(5-1)-1 + const int mantissa_bits = 4; + const uint8_t mantissa_max = 0xF; // 2^4-1 - uint32_t sign = f8 & 0x80; - uint32_t exponent = (f8 & 0x78) >> 3; - uint32_t mantissa = f8 & 0x07; - uint32_t result = sign << 24; - if (exponent == 0) { - if (mantissa > 0) { - exponent = 0x7f - exponent_bias; - - // yes, 2 times - if ((mantissa & 0x04) == 0) { - mantissa &= 0x03; - mantissa <<= 1; - exponent -= 1; - } - if ((mantissa & 0x04) == 0) { - mantissa &= 0x03; - mantissa <<= 1; - exponent -= 1; - } + uint8_t sign = (fp8 >> 7) & 0x1; + uint8_t exponent = (fp8 >> mantissa_bits) & (0x7F >> mantissa_bits); + uint8_t mantissa = fp8 & mantissa_max; - result |= (mantissa & 0x03) << 21; - result |= exponent << 23; + uint16_t fp16_sign = sign << 15; + uint16_t fp16_exponent = (exponent + (f16_bias - exponent_bias)); + if (exponent == 0) { + // subnormal numbers + fp16_exponent++; + // mantissa != 0 because (fp8 & 0x7F) != 0 && exponent == 0 + while (!(mantissa >> mantissa_bits)) { + mantissa <<= 1; + fp16_exponent--; } - } else { - result |= mantissa << 20; - exponent += 0x7f - exponent_bias; - result |= exponent << 23; + mantissa &= mantissa_max; } + uint16_t fp16_mantissa = mantissa << 6; - return ggml_fp32_to_fp16(*reinterpret_cast(&result)); + return fp16_sign | fp16_exponent << 10 | fp16_mantissa; } -uint16_t f8_e5m2_to_f16(uint8_t fp8) { - uint8_t sign = (fp8 >> 7) & 0x1; - uint8_t exponent = (fp8 >> 2) & 0x1F; - uint8_t mantissa = fp8 & 0x3; - - uint16_t fp16_sign = sign << 15; - uint16_t fp16_exponent; - uint16_t fp16_mantissa; - - if (exponent == 0 && mantissa == 0) { // zero - return fp16_sign; +uint16_t f8_e4m3_to_f16(uint8_t fp8) { + // do we need to support uz? + if ((fp8 & 0x7F) == 0 || (fp8 & 0x7F) == 0x7F) { + // +/- 0 or NaN + return static_cast(fp8) << 8; } + const uint8_t exponent_bias = 0x7; // 2^(4-1)-1 + const uint8_t f16_bias = 0xF; // 2^(5-1)-1 + const int mantissa_bits = 3; + const uint8_t mantissa_max = 0x7; // 2^3-1 - if (exponent == 0x1F) { // NAN and INF - fp16_exponent = 0x1F; - fp16_mantissa = mantissa ? (mantissa << 8) : 0; - return fp16_sign | (fp16_exponent << 10) | fp16_mantissa; - } + uint8_t sign = (fp8 >> 7) & 0x1; + uint8_t exponent = (fp8 >> mantissa_bits) & (0x7F >> mantissa_bits); + uint8_t mantissa = fp8 & mantissa_max; - if (exponent == 0) { // subnormal numbers - fp16_exponent = 0; - fp16_mantissa = (mantissa << 8); - return fp16_sign | fp16_mantissa; + uint16_t fp16_sign = sign << 15; + uint16_t fp16_exponent = (exponent + (f16_bias - exponent_bias)); + if (exponent == 0) { + // subnormal numbers + fp16_exponent++; + // mantissa != 0 because (fp8 & 0x7F) != 0 && exponent == 0 + while (!(mantissa >> mantissa_bits)) { + mantissa <<= 1; + fp16_exponent--; + } + mantissa &= mantissa_max; } + uint16_t fp16_mantissa = mantissa << 7; - // normal numbers - int16_t true_exponent = (int16_t)exponent - 15 + 15; - if (true_exponent <= 0) { - fp16_exponent = 0; - fp16_mantissa = (mantissa << 8); - } else if (true_exponent >= 0x1F) { - fp16_exponent = 0x1F; - fp16_mantissa = 0; - } else { - fp16_exponent = (uint16_t)true_exponent; - fp16_mantissa = mantissa << 8; - } + return fp16_sign | fp16_exponent << 10 | fp16_mantissa; +} - return fp16_sign | (fp16_exponent << 10) | fp16_mantissa; +uint16_t f8_e5m2_to_f16(uint8_t fp8) { + // do we need to support fnuz? + return static_cast(fp8) << 8; } void bf16_to_f32_vec(uint16_t* src, float* dst, int64_t n) { @@ -699,6 +684,13 @@ void bf16_to_f32_vec(uint16_t* src, float* dst, int64_t n) { } } +void f8_e3m4_to_f16_vec(uint8_t* src, uint16_t* dst, int64_t n) { + // support inplace op + for (int64_t i = n - 1; i >= 0; i--) { + dst[i] = f8_e3m4_to_f16(src[i]); + } +} + void f8_e4m3_to_f16_vec(uint8_t* src, uint16_t* dst, int64_t n) { // support inplace op for (int64_t i = n - 1; i >= 0; i--) { @@ -946,6 +938,8 @@ ggml_type str_to_ggml_type(const std::string& dtype) { ttype = GGML_TYPE_F32; } else if (dtype == "F32") { ttype = GGML_TYPE_F32; + } else if (dtype == "F8_E3M4") { + ttype = GGML_TYPE_F16; } else if (dtype == "F8_E4M3") { ttype = GGML_TYPE_F16; } else if (dtype == "F8_E5M2") { @@ -1059,6 +1053,10 @@ bool ModelLoader::init_from_safetensors_file(const std::string& file_path, const if (dtype == "BF16") { tensor_storage.is_bf16 = true; GGML_ASSERT(tensor_storage.nbytes() == tensor_data_size * 2); + } else if (dtype == "F8_E3M4") { + tensor_storage.is_f8_e3m4 = true; + // f8 -> f16 + GGML_ASSERT(tensor_storage.nbytes() == tensor_data_size * 2); } else if (dtype == "F8_E4M3") { tensor_storage.is_f8_e4m3 = true; // f8 -> f16 @@ -1461,10 +1459,10 @@ SDVersion ModelLoader::get_sd_version() { TensorStorage token_embedding_weight, input_block_weight; bool input_block_checked = false; - bool has_multiple_encoders = false; - bool is_unet = false; + bool has_multiple_encoders = false; + bool is_unet = false; - bool is_xl = false; + bool is_xl = false; bool is_flux = false; #define found_family (is_xl || is_flux) @@ -1481,7 +1479,7 @@ SDVersion ModelLoader::get_sd_version() { } if (tensor_storage.name.find("model.diffusion_model.input_blocks.") != std::string::npos) { is_unet = true; - if(has_multiple_encoders){ + if (has_multiple_encoders) { is_xl = true; if (input_block_checked) { break; @@ -1490,7 +1488,7 @@ SDVersion ModelLoader::get_sd_version() { } if (tensor_storage.name.find("conditioner.embedders.1") != std::string::npos || tensor_storage.name.find("cond_stage_model.1") != std::string::npos) { has_multiple_encoders = true; - if(is_unet){ + if (is_unet) { is_xl = true; if (input_block_checked) { break; @@ -1779,6 +1777,9 @@ bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, ggml_backend if (tensor_storage.is_bf16) { // inplace op bf16_to_f32_vec((uint16_t*)dst_tensor->data, (float*)dst_tensor->data, tensor_storage.nelements()); + } else if (tensor_storage.is_f8_e3m4) { + // inplace op + f8_e3m4_to_f16_vec((uint8_t*)dst_tensor->data, (uint16_t*)dst_tensor->data, tensor_storage.nelements()); } else if (tensor_storage.is_f8_e4m3) { // inplace op f8_e4m3_to_f16_vec((uint8_t*)dst_tensor->data, (uint16_t*)dst_tensor->data, tensor_storage.nelements()); @@ -1793,6 +1794,9 @@ bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, ggml_backend if (tensor_storage.is_bf16) { // inplace op bf16_to_f32_vec((uint16_t*)read_buffer.data(), (float*)read_buffer.data(), tensor_storage.nelements()); + } else if (tensor_storage.is_f8_e3m4) { + // inplace op + f8_e3m4_to_f16_vec((uint8_t*)read_buffer.data(), (uint16_t*)read_buffer.data(), tensor_storage.nelements()); } else if (tensor_storage.is_f8_e4m3) { // inplace op f8_e4m3_to_f16_vec((uint8_t*)read_buffer.data(), (uint16_t*)read_buffer.data(), tensor_storage.nelements()); @@ -1811,6 +1815,9 @@ bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, ggml_backend if (tensor_storage.is_bf16) { // inplace op bf16_to_f32_vec((uint16_t*)read_buffer.data(), (float*)read_buffer.data(), tensor_storage.nelements()); + } else if (tensor_storage.is_f8_e3m4) { + // inplace op + f8_e3m4_to_f16_vec((uint8_t*)read_buffer.data(), (uint16_t*)read_buffer.data(), tensor_storage.nelements()); } else if (tensor_storage.is_f8_e4m3) { // inplace op f8_e4m3_to_f16_vec((uint8_t*)read_buffer.data(), (uint16_t*)read_buffer.data(), tensor_storage.nelements()); diff --git a/model.h b/model.h index 95bbf1da2..f5cc1cfe2 100644 --- a/model.h +++ b/model.h @@ -89,6 +89,7 @@ struct TensorStorage { std::string name; ggml_type type = GGML_TYPE_F32; bool is_bf16 = false; + bool is_f8_e3m4 = false; bool is_f8_e4m3 = false; bool is_f8_e5m2 = false; int64_t ne[SD_MAX_DIMS] = {1, 1, 1, 1, 1}; @@ -120,7 +121,7 @@ struct TensorStorage { } int64_t nbytes_to_read() const { - if (is_bf16 || is_f8_e4m3 || is_f8_e5m2) { + if (is_bf16 || is_f8_e3m4 || is_f8_e4m3 || is_f8_e5m2) { return nbytes() / 2; } else { return nbytes(); @@ -168,6 +169,8 @@ struct TensorStorage { const char* type_name = ggml_type_name(type); if (is_bf16) { type_name = "bf16"; + } else if (is_f8_e3m4) { + type_name = "f8_e3m4"; } else if (is_f8_e4m3) { type_name = "f8_e4m3"; } else if (is_f8_e5m2) {