@@ -79,14 +79,14 @@ impl<T> EinSumOp<T> {
7979 /// # Returns
8080 ///
8181 /// An `EinSumAST<T>` node representing the contraction operation.
82- pub fn contraction (
83- lhs : CausalTensor < T > ,
84- rhs : CausalTensor < T > ,
82+ pub fn contraction < L : Into < CausalTensor < T > > , R : Into < CausalTensor < T > > > (
83+ lhs : L ,
84+ rhs : R ,
8585 lhs_axes : Vec < usize > ,
8686 rhs_axes : Vec < usize > ,
8787 ) -> EinSumAST < T > {
88- let lhs_leaf = EinSumOp :: tensor_source ( lhs) ;
89- let rhs_leaf = EinSumOp :: tensor_source ( rhs) ;
88+ let lhs_leaf = EinSumOp :: tensor_source ( lhs. into ( ) ) ;
89+ let rhs_leaf = EinSumOp :: tensor_source ( rhs. into ( ) ) ;
9090 EinSumAST :: with_children (
9191 EinSumOp :: Contraction { lhs_axes, rhs_axes } ,
9292 vec ! [ lhs_leaf, rhs_leaf] ,
@@ -105,8 +105,8 @@ impl<T> EinSumOp<T> {
105105 /// # Returns
106106 ///
107107 /// An `EinSumAST<T>` node representing the reduction operation.
108- pub fn reduction ( operand : CausalTensor < T > , axes : Vec < usize > ) -> EinSumAST < T > {
109- let operand_leaf = EinSumOp :: tensor_source ( operand) ;
108+ pub fn reduction < O : Into < CausalTensor < T > > > ( operand : O , axes : Vec < usize > ) -> EinSumAST < T > {
109+ let operand_leaf = EinSumOp :: tensor_source ( operand. into ( ) ) ;
110110 EinSumAST :: with_children ( EinSumOp :: Reduction { axes } , vec ! [ operand_leaf] )
111111 }
112112
@@ -122,9 +122,12 @@ impl<T> EinSumOp<T> {
122122 /// # Returns
123123 ///
124124 /// An `EinSumAST<T>` node representing the matrix multiplication operation.
125- pub fn mat_mul ( lhs : CausalTensor < T > , rhs : CausalTensor < T > ) -> EinSumAST < T > {
126- let lhs_leaf = EinSumOp :: tensor_source ( lhs) ;
127- let rhs_leaf = EinSumOp :: tensor_source ( rhs) ;
125+ pub fn mat_mul < L : Into < CausalTensor < T > > , R : Into < CausalTensor < T > > > (
126+ lhs : L ,
127+ rhs : R ,
128+ ) -> EinSumAST < T > {
129+ let lhs_leaf = EinSumOp :: tensor_source ( lhs. into ( ) ) ;
130+ let rhs_leaf = EinSumOp :: tensor_source ( rhs. into ( ) ) ;
128131 EinSumAST :: with_children ( EinSumOp :: MatMul , vec ! [ lhs_leaf, rhs_leaf] )
129132 }
130133
@@ -140,9 +143,12 @@ impl<T> EinSumOp<T> {
140143 /// # Returns
141144 ///
142145 /// An `EinSumAST<T>` node representing the dot product operation.
143- pub fn dot_prod ( lhs : CausalTensor < T > , rhs : CausalTensor < T > ) -> EinSumAST < T > {
144- let lhs_leaf = EinSumOp :: tensor_source ( lhs) ;
145- let rhs_leaf = EinSumOp :: tensor_source ( rhs) ;
146+ pub fn dot_prod < L : Into < CausalTensor < T > > , R : Into < CausalTensor < T > > > (
147+ lhs : L ,
148+ rhs : R ,
149+ ) -> EinSumAST < T > {
150+ let lhs_leaf = EinSumOp :: tensor_source ( lhs. into ( ) ) ;
151+ let rhs_leaf = EinSumOp :: tensor_source ( rhs. into ( ) ) ;
146152 EinSumAST :: with_children ( EinSumOp :: DotProd , vec ! [ lhs_leaf, rhs_leaf] )
147153 }
148154
@@ -159,8 +165,8 @@ impl<T> EinSumOp<T> {
159165 /// # Returns
160166 ///
161167 /// An `EinSumAST<T>` node representing the trace operation.
162- pub fn trace ( operand : CausalTensor < T > , axes1 : usize , axes2 : usize ) -> EinSumAST < T > {
163- let operand_leaf = EinSumOp :: tensor_source ( operand) ;
168+ pub fn trace < O : Into < CausalTensor < T > > > ( operand : O , axes1 : usize , axes2 : usize ) -> EinSumAST < T > {
169+ let operand_leaf = EinSumOp :: tensor_source ( operand. into ( ) ) ;
164170 EinSumAST :: with_children ( EinSumOp :: Trace { axes1, axes2 } , vec ! [ operand_leaf] )
165171 }
166172
@@ -176,9 +182,12 @@ impl<T> EinSumOp<T> {
176182 /// # Returns
177183 ///
178184 /// An `EinSumAST<T>` node representing the tensor product operation.
179- pub fn tensor_product ( lhs : CausalTensor < T > , rhs : CausalTensor < T > ) -> EinSumAST < T > {
180- let lhs_leaf = EinSumOp :: tensor_source ( lhs) ;
181- let rhs_leaf = EinSumOp :: tensor_source ( rhs) ;
185+ pub fn tensor_product < L : Into < CausalTensor < T > > , R : Into < CausalTensor < T > > > (
186+ lhs : L ,
187+ rhs : R ,
188+ ) -> EinSumAST < T > {
189+ let lhs_leaf = EinSumOp :: tensor_source ( lhs. into ( ) ) ;
190+ let rhs_leaf = EinSumOp :: tensor_source ( rhs. into ( ) ) ;
182191 EinSumAST :: with_children ( EinSumOp :: TensorProduct , vec ! [ lhs_leaf, rhs_leaf] )
183192 }
184193
@@ -194,9 +203,12 @@ impl<T> EinSumOp<T> {
194203 /// # Returns
195204 ///
196205 /// An `EinSumAST<T>` node representing the element-wise product operation.
197- pub fn element_wise_product ( lhs : CausalTensor < T > , rhs : CausalTensor < T > ) -> EinSumAST < T > {
198- let lhs_leaf = EinSumOp :: tensor_source ( lhs) ;
199- let rhs_leaf = EinSumOp :: tensor_source ( rhs) ;
206+ pub fn element_wise_product < L : Into < CausalTensor < T > > , R : Into < CausalTensor < T > > > (
207+ lhs : L ,
208+ rhs : R ,
209+ ) -> EinSumAST < T > {
210+ let lhs_leaf = EinSumOp :: tensor_source ( lhs. into ( ) ) ;
211+ let rhs_leaf = EinSumOp :: tensor_source ( rhs. into ( ) ) ;
200212 EinSumAST :: with_children ( EinSumOp :: ElementWiseProduct , vec ! [ lhs_leaf, rhs_leaf] )
201213 }
202214
@@ -212,8 +224,8 @@ impl<T> EinSumOp<T> {
212224 /// # Returns
213225 ///
214226 /// An `EinSumAST<T>` node representing the transpose operation.
215- pub fn transpose ( operand : CausalTensor < T > , new_order : Vec < usize > ) -> EinSumAST < T > {
216- let operand_leaf = EinSumOp :: tensor_source ( operand) ;
227+ pub fn transpose < O : Into < CausalTensor < T > > > ( operand : O , new_order : Vec < usize > ) -> EinSumAST < T > {
228+ let operand_leaf = EinSumOp :: tensor_source ( operand. into ( ) ) ;
217229 EinSumAST :: with_children ( EinSumOp :: Transpose { new_order } , vec ! [ operand_leaf] )
218230 }
219231
@@ -230,12 +242,12 @@ impl<T> EinSumOp<T> {
230242 /// # Returns
231243 ///
232244 /// An `EinSumAST<T>` node representing the diagonal extraction operation.
233- pub fn diagonal_extraction (
234- operand : CausalTensor < T > ,
245+ pub fn diagonal_extraction < O : Into < CausalTensor < T > > > (
246+ operand : O ,
235247 axes1 : usize ,
236248 axes2 : usize ,
237249 ) -> EinSumAST < T > {
238- let operand_leaf = EinSumOp :: tensor_source ( operand) ;
250+ let operand_leaf = EinSumOp :: tensor_source ( operand. into ( ) ) ;
239251 EinSumAST :: with_children (
240252 EinSumOp :: DiagonalExtraction { axes1, axes2 } ,
241253 vec ! [ operand_leaf] ,
@@ -254,9 +266,12 @@ impl<T> EinSumOp<T> {
254266 /// # Returns
255267 ///
256268 /// An `EinSumAST<T>` node representing the batch matrix multiplication operation.
257- pub fn batch_mat_mul ( lhs : CausalTensor < T > , rhs : CausalTensor < T > ) -> EinSumAST < T > {
258- let lhs_leaf = EinSumOp :: tensor_source ( lhs) ;
259- let rhs_leaf = EinSumOp :: tensor_source ( rhs) ;
269+ pub fn batch_mat_mul < L : Into < CausalTensor < T > > , R : Into < CausalTensor < T > > > (
270+ lhs : L ,
271+ rhs : R ,
272+ ) -> EinSumAST < T > {
273+ let lhs_leaf = EinSumOp :: tensor_source ( lhs. into ( ) ) ;
274+ let rhs_leaf = EinSumOp :: tensor_source ( rhs. into ( ) ) ;
260275 EinSumAST :: with_children ( EinSumOp :: BatchMatMul , vec ! [ lhs_leaf, rhs_leaf] )
261276 }
262277}
0 commit comments