+ # coding=utf-8
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """Testing torch.export compatibility for Mixtral models."""
+
+ import unittest
+
+ import torch
+ import torch.export as te
+
+ from transformers import MixtralConfig
+ from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
+ from transformers.testing_utils import require_torch
+
+
+ @require_torch
+ class MixtralTorchExportTest(unittest.TestCase):
+     """Test torch.export compatibility for Mixtral MoE components."""
+
+     def setUp(self):
+         """Set up test configuration."""
+         self.config = MixtralConfig(
+             hidden_size=128,
+             intermediate_size=256,
+             num_local_experts=8,
+             num_experts_per_tok=2,
+             router_jitter_noise=0.0,
+         )
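+         # A deliberately small configuration keeps the MoE block cheap to construct and export in tests.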
+
+     def test_moe_block_torch_export(self):
+         """Test that MixtralSparseMoeBlock can be exported with torch.export."""
+         # Create MoE block
+         moe_block = MixtralSparseMoeBlock(self.config)
+         moe_block.eval()
+
+         # Move to meta device for export testing
+         moe_block = moe_block.to("meta")
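+         # Meta tensors carry only shapes and dtypes, so export can trace the graph
+         # without allocating or initializing real expert weights.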
+
+         # Create test input
+         batch_size, seq_len = 2, 8
+         hidden_states = torch.randn(
+             batch_size, seq_len, self.config.hidden_size,
+             device="meta",
+         )
+
+         # Test torch.export - should not raise GuardOnDataDependentSymNode error
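+         # strict=False selects non-strict export, which traces through the Python code directly
+         # and tends to be more tolerant of model-side control flow than the default strict mode.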
+         try:
+             exported_program = te.export(
+                 moe_block,
+                 args=(hidden_states,),
+                 kwargs={},
+                 strict=False,
+             )
+             # If export succeeds, the test passes
+             self.assertIsNotNone(exported_program)
+         except Exception as e:
+             # Check if it's the specific error we're trying to avoid
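+             # (GuardOnDataDependentSymNode is raised when tracing reaches a value that depends on
+             # tensor data, e.g. the index tensors returned by torch.nonzero().)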
+             error_msg = str(e)
+             if "GuardOnDataDependentSymNode" in error_msg or "nonzero" in error_msg.lower():
+                 self.fail(
+                     f"torch.export failed with data-dependent operation error: {error_msg}\n"
+                     "This suggests the .nonzero() fix is not working properly."
+                 )
+             else:
+                 # Re-raise other unexpected errors
+                 raise
+
+     def test_moe_block_functionality(self):
+         """Test that MoE block maintains correct functionality after the fix."""
+         # Create MoE block
+         moe_block = MixtralSparseMoeBlock(self.config)
+         moe_block.eval()
+
+         # Create test input
+         batch_size, seq_len = 2, 4
+         hidden_states = torch.randn(batch_size, seq_len, self.config.hidden_size)
+
+         # Forward pass
+         with torch.no_grad():
+             output, router_logits = moe_block(hidden_states)
+
+         # Verify output shapes
+         self.assertEqual(output.shape, hidden_states.shape)
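+         # Router logits are produced per token, hence the flattened (batch_size * seq_len, num_experts) shape.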
+         self.assertEqual(
+             router_logits.shape,
+             (batch_size * seq_len, self.config.num_local_experts),
+         )
+
+         # Verify that outputs are not all zeros (computation happened)
+         self.assertFalse(torch.allclose(output, torch.zeros_like(output)))
+
+         # Test with different input to ensure different outputs
+         hidden_states2 = torch.randn(batch_size, seq_len, self.config.hidden_size)
+         with torch.no_grad():
+             output2, _ = moe_block(hidden_states2)
+
+         # Outputs should be different for different inputs
+         self.assertFalse(torch.allclose(output, output2))
+
+     def test_moe_block_export_with_different_configs(self):
+         """Test torch.export with various expert configurations."""
+         test_configs = [
+             # (num_experts, top_k)
+             (4, 2),
+             (8, 2),
+             (16, 2),
+             (8, 4),
+         ]
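+         # Varying the expert count and top-k exercises different routing shapes in the expert dispatch.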
+
+         for num_experts, top_k in test_configs:
+             with self.subTest(num_experts=num_experts, top_k=top_k):
+                 config = MixtralConfig(
+                     hidden_size=64,
+                     intermediate_size=128,
+                     num_local_experts=num_experts,
+                     num_experts_per_tok=top_k,
+                     router_jitter_noise=0.0,
+                 )
+
+                 moe_block = MixtralSparseMoeBlock(config)
+                 moe_block.eval()
+                 moe_block = moe_block.to("meta")
+
+                 hidden_states = torch.randn(1, 4, config.hidden_size, device="meta")
+
+                 # Should export without errors
+                 try:
+                     exported_program = te.export(
+                         moe_block,
+                         args=(hidden_states,),
+                         kwargs={},
+                         strict=False,
+                     )
+                     self.assertIsNotNone(exported_program)
+                 except Exception as e:
+                     if "GuardOnDataDependentSymNode" in str(e):
+                         self.fail(f"Export failed for config ({num_experts}, {top_k}): {e}")
+                     else:
+                         raise
+
+
+ if __name__ == "__main__":
+     unittest.main()