Proper lowering of fp8 types from ml_dtypes #164895

@rrrrice

Example MLIR:

module {
  func.func @extend_f8_to_f32(%arg0: f8E5M2) -> f32 {
    %0 = arith.extf %arg0 : f8E5M2 to f32
    return %0 : f32
  }
  func.func @truncate_f32_to_f8(%arg0: f32) -> f8E5M2 {
    %0 = arith.truncf %arg0 : f32 to f8E5M2
    return %0 : f8E5M2
  }
}

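During lowering to the LLVM dialect, the fp8 type (f8E5M2) is converted to i8, so the resulting llvm.fpext and llvm.fptrunc ops fail verification: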

import ml_dtypes

# module_finished is assumed to hold the MLIR module shown above.
backend = LLVMJITBackend()
compiled = backend.compile(
    module_finished,
    kernel_name="extend_f8_to_f32",
    pipeline=Pipeline().lower_to_llvm(),
)
invoker = backend.load(compiled)
input_val = ml_dtypes.float8_e5m2(0.5)
result = invoker.extend_f8_to_f32(input_val)

mlir.extras.runtime.passes.MlirCompilerError: Lowering IR failed with the following diagnostics:

********************************************************************************
Failure while executing pass pipeline:
error: unknown: 'llvm.fpext' op operand #0 must be floating point LLVM type or LLVM dialect-compatible vector of floating point LLVM type, but got 'i8'
note: unknown: see current operation: %1 = "llvm.fpext"(%arg0) : (i8) -> f32
error: unknown: 'llvm.fptrunc' op result #0 must be floating point LLVM type or LLVM dialect-compatible vector of floating point LLVM type, but got 'i8'
note: unknown: see current operation: %0 = "llvm.fptrunc"(%arg0) : (f32) -> i8
********************************************************************************
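
For reference, here is a rough sketch of the numerics a correct lowering of `arith.extf` / `arith.truncf` on f8E5M2 would be expected to produce. It is not taken from MLIR or ml_dtypes; the helper names `f8e5m2_bits_to_float` and `float_to_f8e5m2_bits` are hypothetical. It relies only on the fact that f8E5M2 shares its sign and exponent layout with IEEE half precision and keeps the top 2 mantissa bits, and it assumes round-to-nearest-even on truncation.

```python
import math
import struct


def f8e5m2_bits_to_float(bits: int) -> float:
    """Extend an f8E5M2 bit pattern (0..255) to a Python float.

    Padding the 8 bits with 8 zero bits yields the fp16 encoding
    of exactly the same value, which struct can decode.
    """
    if not 0 <= bits <= 0xFF:
        raise ValueError("expected an 8-bit pattern")
    (value,) = struct.unpack("<e", struct.pack("<H", bits << 8))
    return value


def float_to_f8e5m2_bits(value: float) -> int:
    """Truncate a float to an f8E5M2 bit pattern.

    Assumptions of this sketch: the value is first rounded to fp16
    (so rare double-rounding differences are possible), and
    struct raises OverflowError for finite magnitudes beyond the
    fp16 range instead of saturating to infinity.
    """
    if math.isnan(value):
        return 0x7E  # a quiet NaN pattern: exponent all ones, mantissa 0b10
    (half_bits,) = struct.unpack("<H", struct.pack("<e", value))
    upper = half_bits >> 8    # sign, 5 exponent bits, top 2 mantissa bits
    lower = half_bits & 0xFF  # the 8 mantissa bits being dropped
    # Round to nearest, ties to even; a carry may roll into the exponent.
    if lower > 0x80 or (lower == 0x80 and (upper & 1)):
        upper += 1
    return upper & 0xFF


if __name__ == "__main__":
    bits = float_to_f8e5m2_bits(0.5)
    print(hex(bits), f8e5m2_bits_to_float(bits))
```

The input used above, `ml_dtypes.float8_e5m2(0.5)`, corresponds to the bit pattern `0x38` in this encoding, so a working `extend_f8_to_f32` should return `0.5` for it.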
