@@ -1161,6 +1161,7 @@ def aten_bernoulli_p(self: TTensor, p: float) -> TTensor:
11611161 return op .CastLike (sampled , self )
11621162
11631163
1164+ @torch_op ("aten::bilinear" , trace_only = True )
11641165def aten_bilinear (
11651166 input1 : TensorType ,
11661167 input2 : TensorType ,
@@ -1169,7 +1170,23 @@ def aten_bilinear(
11691170) -> TensorType :
11701171 """bilinear(Tensor input1, Tensor input2, Tensor weight, Tensor? bias=None) -> Tensor"""
11711172
1172- raise NotImplementedError ()
1173+ # Bilinear transformation: y = x1^T A x2 + b
1174+ # input1 shape: (..., in1_features)
1175+ # input2 shape: (..., in2_features)
1176+ # weight shape: (out_features, in1_features, in2_features)
1177+ # bias shape: (out_features) - optional
1178+ # output shape: (..., out_features)
1179+
1180+ # Use Einsum to compute the bilinear transformation
1181+ # "...i,oij,...j->...o" means:
1182+ # - input1[..., i] * weight[o, i, j] * input2[..., j] -> output[..., o]
1183+ result = op .Einsum (input1 , weight , input2 , equation = "...i,oij,...j->...o" )
1184+
1185+ # Add bias if provided
1186+ if bias is not None :
1187+ result = op .Add (result , bias )
1188+
1189+ return result
11731190
11741191
11751192def aten_binary_cross_entropy_with_logits (
@@ -7284,7 +7301,7 @@ def aten_rsub(self: TReal, other: TReal, alpha: float = 1.0) -> TReal:
72847301
72857302@torch_op ("aten::scalar_tensor" , trace_only = True )
72867303def aten_scalar_tensor (
7287- s : float ,
7304+ s : TensorType ,
72887305 dtype : int = FLOAT .dtype ,
72897306 layout : str = "" ,
72907307 device : str = "" ,
@@ -7322,17 +7339,35 @@ def aten_scalar_tensor_complex(
73227339 return result
73237340
73247341
7325- @torch_op (( "aten::scatter.value" , "aten::scatter. src") , trace_only = True )
7326- def aten_scatter (
7327- self : TReal ,
7342+ @torch_op ("aten::scatter.src" , trace_only = True )
7343+ def aten_scatter_src (
7344+ self : TTensor ,
73287345 dim : int , # we have to use int here because ScatterElements() will use this attribute
73297346 index : TInt ,
7330- src : TReal ,
7331- ) -> TReal :
7332- """scatter_add(Tensor self, int dim, Tensor index, Tensor src) -> Tensor"""
7347+ src : TTensor ,
7348+ ) -> TTensor :
7349+ """scatter.src(Tensor self, int dim, Tensor index, Tensor src) -> Tensor"""
7350+ if len (index .shape ) == 0 :
7351+ index = op .Unsqueeze (index , [0 ])
7352+ if len (src .shape ) == 0 :
7353+ src = op .Unsqueeze (src , [0 ])
7354+ return op .ScatterElements (self , index , src , axis = dim )
73337355
7334- update = op .Expand (src , op .Shape (index ))
7335- return op .ScatterElements (self , index , update , axis = dim )
7356+
7357+ @torch_op ("aten::scatter.value" , trace_only = True )
7358+ def aten_scatter_value (
7359+ self : TTensor ,
7360+ dim : int , # we have to use int here because ScatterElements() will use this attribute
7361+ index : TInt ,
7362+ value : float ,
7363+ ) -> TTensor :
7364+ """scatter.value(Tensor self, int dim, Tensor index, Scalar value) -> Tensor"""
7365+ # Ensure value is a scalar tensor and expand it to match index shape
7366+ if len (index .shape ) == 0 :
7367+ index = op .Unsqueeze (index , [0 ])
7368+ scalar_tensor = ir .tensor ([value ], dtype = self .dtype )
7369+ src = op .ConstantOfShape (op .Shape (index ), value = scalar_tensor )
7370+ return op .ScatterElements (self , index , src , axis = dim )
73367371
73377372
73387373@torch_op ("aten::scatter_add" , trace_only = True )
0 commit comments