3737from ..models .attention_flax import FlaxWanAttention
3838from maxdiffusion .pyconfig import HyperParameters
3939from maxdiffusion .pipelines .wan .wan_pipeline import WanPipeline
40+ import qwix
41+ import flax
42+
43+ flax .config .update ('flax_always_shard_variable' , False )
44+ RealQtRule = qwix .QtRule
4045
4146
4247IN_GITHUB_ACTIONS = os .getenv ("GITHUB_ACTIONS" ) == "true"
@@ -282,6 +287,10 @@ def test_get_qt_provider(self, mock_qt_rule):
282287 """
283288 Tests the provider logic for all config branches.
284289 """
290+ def create_real_rule_instance (* args , ** kwargs ):
291+ return RealQtRule (* args , ** kwargs )
292+ mock_qt_rule .side_effect = create_real_rule_instance
293+
285294 # Case 1: Quantization disabled
286295 config_disabled = Mock (spec = HyperParameters )
287296 config_disabled .use_qwix_quantization = False
@@ -301,7 +310,7 @@ def test_get_qt_provider(self, mock_qt_rule):
301310 config_fp8 = Mock (spec = HyperParameters )
302311 config_fp8 .use_qwix_quantization = True
303312 config_fp8 .quantization = "fp8"
304- config_int8 .qwix_module_path = ".*"
313+ config_fp8 .qwix_module_path = ".*"
305314 provider_fp8 = WanPipeline .get_qt_provider (config_fp8 )
306315 self .assertIsNotNone (provider_fp8 )
307316 mock_qt_rule .assert_called_once_with (module_path = ".*" , weight_qtype = jnp .float8_e4m3fn , act_qtype = jnp .float8_e4m3fn , op_names = ("dot_general" ,"einsum" , "conv_general_dilated" ))
@@ -312,7 +321,7 @@ def test_get_qt_provider(self, mock_qt_rule):
312321 config_fp8_full .use_qwix_quantization = True
313322 config_fp8_full .quantization = "fp8_full"
314323 config_fp8_full .quantization_calibration_method = "absmax"
315- config_int8 .qwix_module_path = ".*"
324+ config_fp8_full .qwix_module_path = ".*"
316325 provider_fp8_full = WanPipeline .get_qt_provider (config_fp8_full )
317326 self .assertIsNotNone (provider_fp8_full )
318327 expected_calls = [
@@ -361,6 +370,7 @@ def test_quantize_transformer_enabled(self, mock_get_dummy_inputs, mock_quantize
361370 mock_config .quantization = "fp8_full"
362371 mock_config .qwix_module_path = ".*"
363372 mock_config .per_device_batch_size = 1
373+ mock_config .quantization_calibration_method = "absmax"
364374
365375 mock_model = Mock (spec = WanModel )
366376 mock_pipeline = Mock ()
0 commit comments