1616"""Testing model conversion for a few gen-ai models."""
1717
1818import ai_edge_torch
19- from ai_edge_torch import config as ai_edge_config
2019from ai_edge_torch .generative .examples .amd_llama_135m import amd_llama_135m
2120from ai_edge_torch .generative .examples .gemma import gemma1
2221from ai_edge_torch .generative .examples .gemma import gemma2
@@ -91,35 +90,35 @@ def _test_model(self, config, model, signature_name, atol, rtol):
9190 )
9291
9392 @googletest .skipIf (
94- ai_edge_config . Config . use_torch_xla ,
95- reason = "tests with custom ops are not supported on oss" ,
93+ ai_edge_torch . config . in_oss ,
94+ reason = "tests with custom ops are not supported in oss" ,
9695 )
9796 def test_gemma1 (self ):
9897 config = gemma1 .get_fake_model_config ()
9998 pytorch_model = gemma1 .Gemma1 (config ).eval ()
10099 self ._test_model (config , pytorch_model , "prefill" , atol = 1e-3 , rtol = 1e-5 )
101100
102101 @googletest .skipIf (
103- ai_edge_config . Config . use_torch_xla ,
104- reason = "tests with custom ops are not supported on oss" ,
102+ ai_edge_torch . config . in_oss ,
103+ reason = "tests with custom ops are not supported in oss" ,
105104 )
106105 def test_gemma2 (self ):
107106 config = gemma2 .get_fake_model_config ()
108107 pytorch_model = gemma2 .Gemma2 (config ).eval ()
109108 self ._test_model (config , pytorch_model , "prefill" , atol = 1e-4 , rtol = 1e-5 )
110109
111110 @googletest .skipIf (
112- ai_edge_config . Config . use_torch_xla ,
113- reason = "tests with custom ops are not supported on oss" ,
111+ ai_edge_torch . config . in_oss ,
112+ reason = "tests with custom ops are not supported in oss" ,
114113 )
115114 def test_llama (self ):
116115 config = llama .get_fake_model_config ()
117116 pytorch_model = llama .Llama (config ).eval ()
118117 self ._test_model (config , pytorch_model , "prefill" , atol = 1e-3 , rtol = 1e-5 )
119118
120119 @googletest .skipIf (
121- ai_edge_config . Config . use_torch_xla ,
122- reason = "tests with custom ops are not supported on oss" ,
120+ ai_edge_torch . config . in_oss ,
121+ reason = "tests with custom ops are not supported in oss" ,
123122 )
124123 def test_phi2 (self ):
125124 config = phi2 .get_fake_model_config ()
@@ -128,53 +127,53 @@ def test_phi2(self):
128127 self ._test_model (config , pytorch_model , "prefill" , atol = 1e-3 , rtol = 1e-5 )
129128
130129 @googletest .skipIf (
131- ai_edge_config . Config . use_torch_xla ,
132- reason = "tests with custom ops are not supported on oss" ,
130+ ai_edge_torch . config . in_oss ,
131+ reason = "tests with custom ops are not supported in oss" ,
133132 )
134133 def test_phi3 (self ):
135134 config = phi3 .get_fake_model_config ()
136135 pytorch_model = phi3 .Phi3_5Mini (config ).eval ()
137136 self ._test_model (config , pytorch_model , "prefill" , atol = 1e-5 , rtol = 1e-5 )
138137
139138 @googletest .skipIf (
140- ai_edge_config . Config . use_torch_xla ,
141- reason = "tests with custom ops are not supported on oss" ,
139+ ai_edge_torch . config . in_oss ,
140+ reason = "tests with custom ops are not supported in oss" ,
142141 )
143142 def test_smollm (self ):
144143 config = smollm .get_fake_model_config ()
145144 pytorch_model = smollm .SmolLM (config ).eval ()
146145 self ._test_model (config , pytorch_model , "prefill" , atol = 1e-4 , rtol = 1e-5 )
147146
148147 @googletest .skipIf (
149- ai_edge_config . Config . use_torch_xla ,
150- reason = "tests with custom ops are not supported on oss" ,
148+ ai_edge_torch . config . in_oss ,
149+ reason = "tests with custom ops are not supported in oss" ,
151150 )
152151 def test_openelm (self ):
153152 config = openelm .get_fake_model_config ()
154153 pytorch_model = openelm .OpenELM (config ).eval ()
155154 self ._test_model (config , pytorch_model , "prefill" , atol = 1e-4 , rtol = 1e-5 )
156155
157156 @googletest .skipIf (
158- ai_edge_config . Config . use_torch_xla ,
159- reason = "tests with custom ops are not supported on oss" ,
157+ ai_edge_torch . config . in_oss ,
158+ reason = "tests with custom ops are not supported in oss" ,
160159 )
161160 def test_qwen (self ):
162161 config = qwen .get_fake_model_config ()
163162 pytorch_model = qwen .Qwen (config ).eval ()
164163 self ._test_model (config , pytorch_model , "prefill" , atol = 1e-3 , rtol = 1e-5 )
165164
166165 @googletest .skipIf (
167- ai_edge_config . Config . use_torch_xla ,
168- reason = "tests with custom ops are not supported on oss" ,
166+ ai_edge_torch . config . in_oss ,
167+ reason = "tests with custom ops are not supported in oss" ,
169168 )
170169 def test_amd_llama_135m (self ):
171170 config = amd_llama_135m .get_fake_model_config ()
172171 pytorch_model = amd_llama_135m .AmdLlama (config ).eval ()
173172 self ._test_model (config , pytorch_model , "prefill" , atol = 1e-5 , rtol = 1e-5 )
174173
175174 @googletest .skipIf (
176- ai_edge_config . Config . use_torch_xla ,
177- reason = "tests with custom ops are not supported on oss" ,
175+ ai_edge_torch . config . in_oss ,
176+ reason = "tests with custom ops are not supported in oss" ,
178177 )
179178 def disabled_test_paligemma (self ):
180179 config = paligemma .get_fake_model_config ()
@@ -222,8 +221,8 @@ def disabled_test_paligemma(self):
222221 )
223222
224223 @googletest .skipIf (
225- ai_edge_config . Config . use_torch_xla ,
226- reason = "tests with custom ops are not supported on oss" ,
224+ ai_edge_torch . config . in_oss ,
225+ reason = "tests with custom ops are not supported in oss" ,
227226 )
228227 def test_stable_diffusion_clip (self ):
229228 config = sd_clip .get_fake_model_config ()
@@ -254,8 +253,8 @@ def test_stable_diffusion_clip(self):
254253 )
255254
256255 @googletest .skipIf (
257- ai_edge_config . Config . use_torch_xla ,
258- reason = "tests with custom ops are not supported on oss" ,
256+ ai_edge_torch . config . in_oss ,
257+ reason = "tests with custom ops are not supported in oss" ,
259258 )
260259 def test_stable_diffusion_diffusion (self ):
261260 config = sd_diffusion .get_fake_model_config (2 )
@@ -296,8 +295,8 @@ def test_stable_diffusion_diffusion(self):
296295 )
297296
298297 @googletest .skipIf (
299- ai_edge_config . Config . use_torch_xla ,
300- reason = "tests with custom ops are not supported on oss" ,
298+ ai_edge_torch . config . in_oss ,
299+ reason = "tests with custom ops are not supported in oss" ,
301300 )
302301 def test_stable_diffusion_decoder (self ):
303302 config = sd_decoder .get_fake_model_config ()
0 commit comments