
Commit 972b4ff

Fix Unit test failure for JAX/Flax version update (#264)
1 parent 073d831 · commit 972b4ff

File tree

2 files changed: +14 -3 lines changed

src/maxdiffusion/tests/wan_transformer_test.py
src/maxdiffusion/tests/wan_vae_test.py

src/maxdiffusion/tests/wan_transformer_test.py

Lines changed: 12 additions & 2 deletions

@@ -37,6 +37,11 @@
 from ..models.attention_flax import FlaxWanAttention
 from maxdiffusion.pyconfig import HyperParameters
 from maxdiffusion.pipelines.wan.wan_pipeline import WanPipeline
+import qwix
+import flax
+
+flax.config.update('flax_always_shard_variable', False)
+RealQtRule = qwix.QtRule


 IN_GITHUB_ACTIONS = os.getenv("GITHUB_ACTIONS") == "true"

@@ -282,6 +287,10 @@ def test_get_qt_provider(self, mock_qt_rule):
     """
     Tests the provider logic for all config branches.
     """
+    def create_real_rule_instance(*args, **kwargs):
+      return RealQtRule(*args, **kwargs)
+    mock_qt_rule.side_effect = create_real_rule_instance
+
     # Case 1: Quantization disabled
     config_disabled = Mock(spec=HyperParameters)
     config_disabled.use_qwix_quantization = False

@@ -301,7 +310,7 @@ def test_get_qt_provider(self, mock_qt_rule):
     config_fp8 = Mock(spec=HyperParameters)
     config_fp8.use_qwix_quantization = True
     config_fp8.quantization = "fp8"
-    config_int8.qwix_module_path = ".*"
+    config_fp8.qwix_module_path = ".*"
     provider_fp8 = WanPipeline.get_qt_provider(config_fp8)
     self.assertIsNotNone(provider_fp8)
     mock_qt_rule.assert_called_once_with(module_path=".*", weight_qtype=jnp.float8_e4m3fn, act_qtype=jnp.float8_e4m3fn, op_names=("dot_general","einsum", "conv_general_dilated"))

@@ -312,7 +321,7 @@ def test_get_qt_provider(self, mock_qt_rule):
     config_fp8_full.use_qwix_quantization = True
     config_fp8_full.quantization = "fp8_full"
     config_fp8_full.quantization_calibration_method = "absmax"
-    config_int8.qwix_module_path = ".*"
+    config_fp8_full.qwix_module_path = ".*"
     provider_fp8_full = WanPipeline.get_qt_provider(config_fp8_full)
     self.assertIsNotNone(provider_fp8_full)
     expected_calls = [

@@ -361,6 +370,7 @@ def test_quantize_transformer_enabled(self, mock_get_dummy_inputs, mock_quantize
     mock_config.quantization = "fp8_full"
     mock_config.qwix_module_path = ".*"
     mock_config.per_device_batch_size = 1
+    mock_config.quantization_calibration_method = "absmax"

     mock_model = Mock(spec=WanModel)
     mock_pipeline = Mock()
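
The key change in this file is giving the patched qwix.QtRule a side_effect that forwards to the saved RealQtRule class: the test can still assert on the constructor arguments, while get_qt_provider receives real rule objects instead of bare Mocks. Below is a minimal, self-contained sketch of that pattern; Rule, RealRule, and build_rules are hypothetical stand-ins for qwix.QtRule, RealQtRule, and WanPipeline.get_qt_provider, not the actual maxdiffusion code.

# Sketch only: patch the rule class so calls can be asserted, but route
# construction back to the real class so downstream code gets real objects.
from dataclasses import dataclass
from unittest.mock import patch


@dataclass
class Rule:
  module_path: str
  weight_qtype: str


RealRule = Rule  # keep a handle on the real class, as the test does with RealQtRule


def build_rules(path):
  # Stand-in for library code that constructs rule objects from config.
  return [Rule(module_path=path, weight_qtype="fp8")]


with patch(f"{__name__}.Rule") as mock_rule:
  # Route every constructor call back to the real class...
  mock_rule.side_effect = lambda *args, **kwargs: RealRule(*args, **kwargs)
  rules = build_rules(".*")
  # ...so the call arguments can still be asserted on the mock,
  mock_rule.assert_called_once_with(module_path=".*", weight_qtype="fp8")
  # while callers receive a genuine Rule instance rather than a Mock.
  assert isinstance(rules[0], RealRule)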

src/maxdiffusion/tests/wan_vae_test.py

Lines changed: 2 additions & 1 deletion

@@ -46,12 +46,13 @@
 from ..models.wan.wan_utils import load_wan_vae
 from ..utils import load_video
 from ..video_processor import VideoProcessor
+import flax

 THIS_DIR = os.path.dirname(os.path.abspath(__file__))

 CACHE_T = 2

-
+flax.config.update('flax_always_shard_variable', False)
 class TorchWanRMS_norm(nn.Module):
   r"""
   A custom RMS normalization layer.
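
Both test modules flip the same Flax flag at module import time, before any model or variable is created. A minimal sketch of that pattern is below, assuming the flag only needs to be set once per process; the try/except guard is an added assumption for Flax releases that may not define this flag, and is not part of the commit.

import flax

try:
  # Set before any Flax module creates variables, hence at import time.
  flax.config.update('flax_always_shard_variable', False)
except Exception:
  # Assumed guard: older Flax releases without this flag have nothing to disable.
  pass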
