63 changes: 62 additions & 1 deletion c_src/emlx_nif.cpp
@@ -970,6 +970,63 @@ NIF(as_strided) {
TENSOR(mlx::core::as_strided(*t, shape, strides, offset, device));
}

// ============================================================================
// Quantization Operations (for 4-bit model support)
// ============================================================================

// quantized_matmul - Multiplies x with a quantized weight matrix w
// This is the key operation for efficient 4-bit inference
// MLX API: quantized_matmul(x, w, scales, biases, transpose, group_size, bits, stream)
NIF(quantized_matmul) {
TENSOR_PARAM(0, x); // Input tensor [batch, seq, hidden]
TENSOR_PARAM(1, w); // Quantized weights [out, in/8] (uint32, 8 4-bit values packed per word)
TENSOR_PARAM(2, scales); // Scales [out, in/group_size] (bfloat16)
TENSOR_PARAM(3, biases); // Biases [out, in/group_size] (bfloat16)
PARAM(4, bool, transpose);
PARAM(5, int, group_size);
PARAM(6, int, bits);
DEVICE_PARAM(7, device);

TENSOR(mlx::core::quantized_matmul(
*x, *w, *scales, *biases, transpose, group_size, bits, device));
}

// dequantize - Converts quantized weights back to float
// Useful for debugging and verification
// MLX API: dequantize(w, scales, biases, group_size, bits, stream)
NIF(dequantize) {
TENSOR_PARAM(0, w); // Quantized weights (uint32 packed)
TENSOR_PARAM(1, scales); // Scales (bfloat16)
TENSOR_PARAM(2, biases); // Biases (bfloat16)
PARAM(3, int, group_size);
PARAM(4, int, bits);
DEVICE_PARAM(5, device);

TENSOR(mlx::core::dequantize(*w, *scales, *biases, group_size, bits, device));
}

// quantize - Quantizes a float tensor to packed format
// Returns tuple of {weights, scales, biases}
// MLX API: quantize(w, group_size, bits, stream) -> tuple<array, array, array>
NIF(quantize) {
TENSOR_PARAM(0, w); // Float weights to quantize
PARAM(1, int, group_size);
PARAM(2, int, bits);
DEVICE_PARAM(3, device);

try {
auto [qw, scales, biases] = mlx::core::quantize(*w, group_size, bits, device);

ERL_NIF_TERM result_tuple[3];
result_tuple[0] = create_tensor_resource(env, qw);
result_tuple[1] = create_tensor_resource(env, scales);
result_tuple[2] = create_tensor_resource(env, biases);

return nx::nif::ok(env, enif_make_tuple3(env, result_tuple[0], result_tuple[1], result_tuple[2]));
}
CATCH()
}

static ErlNifFunc nif_funcs[] = {
{"strides", 1, strides},
{"as_strided", 5, as_strided},
@@ -1087,7 +1144,11 @@ static ErlNifFunc nif_funcs[] = {
{"max", 4, max},
{"min", 4, min},
{"clip", 4, clip},
{"tri_inv", 3, tri_inv}
{"tri_inv", 3, tri_inv},
// Quantization operations
{"quantized_matmul", 8, quantized_matmul},
{"dequantize", 6, dequantize},
{"quantize", 4, quantize}
};

// Update the NIF initialization
141 changes: 141 additions & 0 deletions lib/emlx.ex
@@ -258,6 +258,97 @@ defmodule EMLX do
defvalue scalar_type(tensor)
defvalue shape(tensor)

## Quantization operations (for 4-bit model support)

@doc """
Performs quantized matrix multiplication.

This is the key operation for efficient 4-bit inference. It multiplies `x` with
quantized weights `w` (packed as uint32), using scales and biases for
dequantization during the computation.

## Parameters
- `x` - Input tensor (e.g., {batch, seq, hidden})
- `w` - Quantized weights as uint32 (8 int4 values packed per uint32)
- `scales` - Per-group scale factors (bfloat16)
- `biases` - Per-group zero points (bfloat16)
- `transpose` - Whether to transpose weights (default: true)
- `group_size` - Number of weights per scale/bias group (default: 64)
- `bits` - Quantization bits (default: 4)
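
## Example

A sketch of typical usage (here `x` is an activation tensor and `w` an
unquantized float weight tensor of shape `{out_features, in_features}`, both
as EMLX device-ref tuples; the names are placeholders):

    # Pack the weights once, then reuse them for every matmul
    {q_w, scales, biases} = EMLX.quantize(w, 64, 4)

    # With transpose: true this is equivalent to multiplying `x` by the
    # dequantized weights, transposed
    y = EMLX.quantized_matmul(x, q_w, scales, biases, true, 64, 4)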
"""
@mlx_function {:quantized_matmul, 8}
def quantized_matmul(
{dev_x, ref_x} = _tensor_x,
{dev_w, ref_w} = _tensor_w,
{dev_s, ref_s} = _tensor_scales,
{dev_b, ref_b} = _tensor_biases,
transpose \\ true,
group_size \\ 64,
bits \\ 4
)
when is_tensor(dev_x, ref_x) and is_tensor(dev_w, ref_w) and
is_tensor(dev_s, ref_s) and is_tensor(dev_b, ref_b) do
device = merge_device(merge_device(dev_x, dev_w), merge_device(dev_s, dev_b))
mlx_device = mlx_device!(device, -1)

EMLX.NIF.quantized_matmul(ref_x, ref_w, ref_s, ref_b, transpose, group_size, bits, mlx_device)
|> unwrap_tensor!(device)
end

@doc """
Dequantizes packed weights to floating point.

Converts quantized weights back to a floating point approximation of the original values.
Useful for debugging and verification.

## Parameters
- `w` - Quantized weights as uint32 (packed int4 values)
- `scales` - Per-group scale factors
- `biases` - Per-group zero points
- `group_size` - Number of weights per group (default: 64)
- `bits` - Quantization bits (default: 4)
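
## Example

A sketch of a round-trip check (`w` is a placeholder float weight tensor as an
EMLX device-ref tuple; expect small differences from 4-bit precision loss):

    {q_w, scales, biases} = EMLX.quantize(w, 64, 4)
    w_approx = EMLX.dequantize(q_w, scales, biases, 64, 4)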
"""
@mlx_function {:dequantize, 6}
def dequantize(
{dev_w, ref_w} = _tensor_w,
{dev_s, ref_s} = _tensor_scales,
{dev_b, ref_b} = _tensor_biases,
group_size \\ 64,
bits \\ 4
)
when is_tensor(dev_w, ref_w) and is_tensor(dev_s, ref_s) and is_tensor(dev_b, ref_b) do
device = merge_device(dev_w, merge_device(dev_s, dev_b))
mlx_device = mlx_device!(device, -1)

EMLX.NIF.dequantize(ref_w, ref_s, ref_b, group_size, bits, mlx_device)
|> unwrap_tensor!(device)
end

@doc """
Quantizes a floating point tensor to packed format.

Returns a tuple of `{quantized_weights, scales, biases}` where:
- `quantized_weights` - Packed uint32 tensor (8 int4 values per uint32)
- `scales` - Per-group scale factors
- `biases` - Per-group zero points

## Parameters
- `w` - Float tensor to quantize
- `group_size` - Number of weights per group (default: 64)
- `bits` - Quantization bits (default: 4)
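
## Example

A sketch (`w` is a placeholder float weight tensor; each element of the
returned tuple is a `{device, ref}` pair on the same device as `w`):

    {q_w, scales, biases} = EMLX.quantize(w, 64, 4)

    # q_w, scales and biases plug directly into quantized_matmul/7,
    # dequantize/5 or quantized_tensor/5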
"""
@mlx_function {:quantize, 4}
def quantize({dev_w, ref_w} = _tensor_w, group_size \\ 64, bits \\ 4)
when is_tensor(dev_w, ref_w) do
device = dev_w
mlx_device = mlx_device!(device, -1)

{weights_ref, scales_ref, biases_ref} =
EMLX.NIF.quantize(ref_w, group_size, bits, mlx_device) |> unwrap!()

{{device, weights_ref}, {device, scales_ref}, {device, biases_ref}}
end

def to_blob({device, ref} = tensor) when is_tensor(device, ref) do
# Two-step to_blob: eval on main scheduler, then copy on dirty scheduler
eval(tensor)
@@ -323,6 +414,56 @@ defmodule EMLX do
defvalue item(tensor)
defvalue strides(tensor)

# ============================================================================
# Quantized Tensor Operations (Backend-Integrated)
# ============================================================================

@doc """
Creates a quantized Nx.Tensor with backend-level quantization options.

This creates an Nx.Tensor where the EMLX.Backend struct contains
quantization metadata. When this tensor is used in `Nx.dot`, the
backend automatically dispatches to `quantized_matmul`.

## Parameters

- `weight_ref` - EMLX device ref for packed uint32 weights
- `scales_ref` - EMLX device ref for per-group scale factors
- `biases_ref` - EMLX device ref for per-group zero points
- `original_shape` - Shape before quantization {out_features, in_features}

## Options

- `:bits` - Quantization bits (default: 4)
- `:group_size` - Weights per scale/bias group (default: 64)

## Example

# Quantize weights
{q_weight, scales, biases} = EMLX.quantize(weight_tensor, 64, 4)

# Create quantized Nx.Tensor
quantized = EMLX.quantized_tensor(q_weight, scales, biases, {512, 4096})

# Standard Nx.dot automatically uses quantized_matmul!
result = Nx.dot(input, quantized)
"""
def quantized_tensor(weight_ref, scales_ref, biases_ref, original_shape, opts \\ []) do
EMLX.Backend.quantized_tensor(weight_ref, scales_ref, biases_ref, original_shape, opts)
end

@doc """
Converts an EMLX device ref back to an Nx.Tensor.

## Example

result_ref = EMLX.some_operation(input)
result_tensor = EMLX.to_nx(result_ref)
"""
def to_nx({device, ref} = device_ref) when is_atom(device) and is_reference(ref) do
EMLX.Backend.to_nx(device_ref)
end

Inline review comment (Collaborator):

I think if you're using this from outside of the Backend module, you should move the implementation here and delegate from there instead.

As this stands, it introduces a circular dependency between EMLX and EMLX.Backend

Reply:

Oh good catch thank you, will do

@behaviour Nx.Defn.Compiler

@impl Nx.Defn.Compiler