@@ -126,50 +126,70 @@ def _impl_unstable_to_stable_fftn(self, gm):
126126
127127 return gm
128128
129- def _impl_unstable_to_stable_one_hot (self , gm ):
129+ def _impl_unstable_to_stable_special_logit (self , gm ):
130130 """
131- Convert torch._C._nn.one_hot to torch.nn.functional.one_hot
131+ Convert torch._C._special.special_logit to torch.special.logit
132132 """
133- import torch .nn .functional as F
134-
135133 issue_nodes = (
136134 node
137135 for node in gm .graph .nodes
138136 if node .op == "call_function"
139137 if hasattr (node .target , "__module__" )
140- if node .target .__module__ == "torch._C._nn "
138+ if node .target .__module__ == "torch._C._special "
141139 if hasattr (node .target , "__name__" )
142- if node .target .__name__ == "one_hot "
140+ if node .target .__name__ == "special_logit "
143141 )
144142 for node in issue_nodes :
145- node .target = F . one_hot
143+ node .target = torch . special . logit
146144
147145 # Recompile the graph
148146 gm .recompile ()
149147
150148 return gm
151149
152- def _impl_unstable_to_stable_special_logit (self , gm ):
150+ # replace this line with modification code for task 116 (torch._C._linalg.linalg_vector_norm)
151+
152+ # replace this line with modification code for task 117 (torch._C._linalg.linalg_norm)
153+
154+ # replace this line with modification code for task 118 (torch._C._nn.softplus)
155+
156+ # replace this line with modification code for task 119 (torch._C._nn.one_hot)
157+
158+ def _impl_unstable_to_stable_one_hot (self , gm ):
153159 """
154- Convert torch._C._special.special_logit to torch.special.logit
160+ Convert torch._C._nn.one_hot to torch.nn.functional.one_hot
155161 """
162+ import torch .nn .functional as F
163+
156164 issue_nodes = (
157165 node
158166 for node in gm .graph .nodes
159167 if node .op == "call_function"
160168 if hasattr (node .target , "__module__" )
161- if node .target .__module__ == "torch._C._special "
169+ if node .target .__module__ == "torch._C._nn "
162170 if hasattr (node .target , "__name__" )
163- if node .target .__name__ == "special_logit "
171+ if node .target .__name__ == "one_hot "
164172 )
165173 for node in issue_nodes :
166- node .target = torch . special . logit
174+ node .target = F . one_hot
167175
168176 # Recompile the graph
169177 gm .recompile ()
170178
171179 return gm
172180
181+ # replace this line with modification code for task 121 (torch._C._set_grad_enabled)
182+
183+ # replace this line with modification code for task 122 (torch._C._log_api_usage_once)
184+
185+ # replace this line with modification code for task 123 (torch._C._nn.pad)
186+
187+ # replace this line with modification code for task 125 (torch._C._nn.gelu)
188+
189+ # replace this line with modification code for task 126 (torch._C._nn.scaled_dot_product_attention)
190+
191+ # replace this line with modification code for task 127 (torch._C._nn.linear)
192+
173193 def unstable_to_stable (self , gm ):
174194 methods = (
175195 name
0 commit comments