Skip to content

Commit 073d831

Browse files
authored
Qwix (#259)
* qwix * fix commits * fix test
1 parent 230c460 commit 073d831

File tree

3 files changed

+49
-16
lines changed

3 files changed

+49
-16
lines changed

src/maxdiffusion/configs/base_wan_14b.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -313,6 +313,7 @@ compile_topology_num_slices: -1 # Number of target slices, set to a positive int
313313
use_qwix_quantization: False # Whether to use qwix for quantization. If set to True, the transformer of WAN will be quantized using qwix.
314314
# Quantization calibration method used for weights and activations. Supported methods can be found in https://github.com/google/qwix/blob/dc2a0770351c740e5ab3cce7c0efe9f7beacce9e/qwix/qconfig.py#L70-L80
315315
quantization_calibration_method: "absmax"
316+
qwix_module_path: ".*"
316317

317318
# Eval model on per eval_every steps. -1 means don't eval.
318319
eval_every: -1

src/maxdiffusion/pipelines/wan/wan_pipeline.py

Lines changed: 24 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -243,34 +243,48 @@ def create_model(rngs: nnx.Rngs, config: HyperParameters):
243243
return wan_vae, vae_cache
244244

245245
@classmethod
246-
def get_basic_config(cls, dtype):
246+
def get_basic_config(cls, dtype, config: HyperParameters):
247247
rules = [
248248
qwix.QtRule(
249-
module_path=".*", # Apply to all modules
249+
module_path=config.qwix_module_path,
250250
weight_qtype=dtype,
251251
act_qtype=dtype,
252+
op_names=("dot_general","einsum", "conv_general_dilated"),
252253
)
253254
]
254255
return rules
255256

256257
@classmethod
257-
def get_fp8_config(cls, quantization_calibration_method: str):
258+
def get_fp8_config(cls, config: HyperParameters):
258259
"""
259260
fp8 config rules with per-tensor calibration.
260261
FLAX API (https://flax-linen.readthedocs.io/en/v0.10.6/guides/quantization/fp8_basics.html#flax-low-level-api):
261262
The autodiff does not automatically use E5M2 for gradients and E4M3 for activations/weights during training, which is the recommended practice.
262263
"""
263264
rules = [
264265
qwix.QtRule(
265-
module_path=".*", # Apply to all modules
266+
module_path=config.qwix_module_path,
266267
weight_qtype=jnp.float8_e4m3fn,
267268
act_qtype=jnp.float8_e4m3fn,
269+
bwd_qtype=jnp.float8_e5m2,
270+
bwd_use_original_residuals=True,
271+
disable_channelwise_axes=True, # per_tensor calibration
272+
weight_calibration_method=config.quantization_calibration_method,
273+
act_calibration_method=config.quantization_calibration_method,
274+
bwd_calibration_method=config.quantization_calibration_method,
275+
op_names=("dot_general","einsum"),
276+
),
277+
qwix.QtRule(
278+
module_path=config.qwix_module_path,
279+
weight_qtype=jnp.float8_e4m3fn, # conv_general_dilated requires the same dtypes
280+
act_qtype=jnp.float8_e4m3fn,
268281
bwd_qtype=jnp.float8_e4m3fn,
269282
bwd_use_original_residuals=True,
270283
disable_channelwise_axes=True, # per_tensor calibration
271-
weight_calibration_method=quantization_calibration_method,
272-
act_calibration_method=quantization_calibration_method,
273-
bwd_calibration_method=quantization_calibration_method,
284+
weight_calibration_method=config.quantization_calibration_method,
285+
act_calibration_method=config.quantization_calibration_method,
286+
bwd_calibration_method=config.quantization_calibration_method,
287+
op_names=("conv_general_dilated",),  # trailing comma required: ("x") is a str, ("x",) is a tuple
274288
)
275289
]
276290
return rules
@@ -281,14 +295,13 @@ def get_qt_provider(cls, config: HyperParameters) -> Optional[qwix.QtProvider]:
281295
if not getattr(config, "use_qwix_quantization", False):
282296
return None
283297

284-
quantization_calibration_method = getattr(config, "quantization_calibration_method", "absmax")
285298
match config.quantization:
286299
case "int8":
287-
return qwix.QtProvider(cls.get_basic_config(jnp.int8))
300+
return qwix.QtProvider(cls.get_basic_config(jnp.int8, config))
288301
case "fp8":
289-
return qwix.QtProvider(cls.get_basic_config(jnp.float8_e4m3fn))
302+
return qwix.QtProvider(cls.get_basic_config(jnp.float8_e4m3fn, config))
290303
case "fp8_full":
291-
return qwix.QtProvider(cls.get_fp8_config(quantization_calibration_method))
304+
return qwix.QtProvider(cls.get_fp8_config(config))
292305
return None
293306

294307
@classmethod

src/maxdiffusion/tests/wan_transformer_test.py

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
import jax.numpy as jnp
2020
import pytest
2121
import unittest
22-
from unittest.mock import Mock, patch
22+
from unittest.mock import Mock, patch, call
2323
from absl.testing import absltest
2424
from flax import nnx
2525
from jax.sharding import Mesh
@@ -291,28 +291,43 @@ def test_get_qt_provider(self, mock_qt_rule):
291291
config_int8 = Mock(spec=HyperParameters)
292292
config_int8.use_qwix_quantization = True
293293
config_int8.quantization = "int8"
294+
config_int8.qwix_module_path = ".*"
294295
provider_int8 = WanPipeline.get_qt_provider(config_int8)
295296
self.assertIsNotNone(provider_int8)
296-
mock_qt_rule.assert_called_once_with(module_path=".*", weight_qtype=jnp.int8, act_qtype=jnp.int8)
297+
mock_qt_rule.assert_called_once_with(module_path=".*", weight_qtype=jnp.int8, act_qtype=jnp.int8, op_names=("dot_general","einsum", "conv_general_dilated"))
297298

298299
# Case 3: Quantization enabled, type 'fp8'
299300
mock_qt_rule.reset_mock()
300301
config_fp8 = Mock(spec=HyperParameters)
301302
config_fp8.use_qwix_quantization = True
302303
config_fp8.quantization = "fp8"
304+
config_fp8.qwix_module_path = ".*"
303305
provider_fp8 = WanPipeline.get_qt_provider(config_fp8)
304306
self.assertIsNotNone(provider_fp8)
305-
mock_qt_rule.assert_called_once_with(module_path=".*", weight_qtype=jnp.float8_e4m3fn, act_qtype=jnp.float8_e4m3fn)
307+
mock_qt_rule.assert_called_once_with(module_path=".*", weight_qtype=jnp.float8_e4m3fn, act_qtype=jnp.float8_e4m3fn, op_names=("dot_general","einsum", "conv_general_dilated"))
306308

307309
# Case 4: Quantization enabled, type 'fp8_full'
308310
mock_qt_rule.reset_mock()
309311
config_fp8_full = Mock(spec=HyperParameters)
310312
config_fp8_full.use_qwix_quantization = True
311313
config_fp8_full.quantization = "fp8_full"
312314
config_fp8_full.quantization_calibration_method = "absmax"
315+
config_fp8_full.qwix_module_path = ".*"
313316
provider_fp8_full = WanPipeline.get_qt_provider(config_fp8_full)
314317
self.assertIsNotNone(provider_fp8_full)
315-
mock_qt_rule.assert_called_once_with(
318+
expected_calls = [
319+
call(module_path=".*", # Apply to all modules
320+
weight_qtype=jnp.float8_e4m3fn,
321+
act_qtype=jnp.float8_e4m3fn,
322+
bwd_qtype=jnp.float8_e5m2,
323+
bwd_use_original_residuals=True,
324+
disable_channelwise_axes=True, # per_tensor calibration
325+
weight_calibration_method=config_fp8_full.quantization_calibration_method,
326+
act_calibration_method=config_fp8_full.quantization_calibration_method,
327+
bwd_calibration_method=config_fp8_full.quantization_calibration_method,
328+
op_names=("dot_general","einsum"),
329+
),
330+
call(
316331
module_path=".*", # Apply to all modules
317332
weight_qtype=jnp.float8_e4m3fn,
318333
act_qtype=jnp.float8_e4m3fn,
@@ -322,7 +337,10 @@ def test_get_qt_provider(self, mock_qt_rule):
322337
weight_calibration_method=config_fp8_full.quantization_calibration_method,
323338
act_calibration_method=config_fp8_full.quantization_calibration_method,
324339
bwd_calibration_method=config_fp8_full.quantization_calibration_method,
325-
)
340+
op_names=("conv_general_dilated",),
341+
)
342+
]
343+
mock_qt_rule.assert_has_calls(expected_calls, any_order=True)
326344

327345
# Case 5: Invalid quantization type
328346
config_invalid = Mock(spec=HyperParameters)
@@ -341,6 +359,7 @@ def test_quantize_transformer_enabled(self, mock_get_dummy_inputs, mock_quantize
341359
mock_config = Mock(spec=HyperParameters)
342360
mock_config.use_qwix_quantization = True
343361
mock_config.quantization = "fp8_full"
362+
mock_config.qwix_module_path = ".*"
344363
mock_config.per_device_batch_size = 1
345364

346365
mock_model = Mock(spec=WanModel)

0 commit comments

Comments
 (0)