@@ -317,45 +317,69 @@ def _tfl_concatenation_lowering(
317317 )
318318
319319
320+ @lower (torch .ops .tfl .fill .default )
321+ def _tfl_fill_lowering (
322+ lctx : LoweringContext ,
323+ dims : Sequence [int | ir .Value ],
324+ fill_value : ir .Value ,
325+ ) -> ir .Value :
326+ dims_ir_value = lowering_utils .convert_shape_to_ir_value (dims )
327+ fill_value_ir_value = lowering_utils .convert_to_ir_value (fill_value )
328+
329+ # Ensure fill_value_ir_value is a scalar (0-D tensor) for TFLite Fill op.
330+ # The TFLite Fill kernel expects the value to be a 0-D tensor.
331+ if isinstance (fill_value_ir_value .type , ir .RankedTensorType ):
332+ tensor_type = fill_value_ir_value .type
333+ # If it's a 1-D tensor with a single element, reshape to 0-D.
334+ if list (tensor_type .shape ) == [1 ]:
335+ scalar_type = ir .RankedTensorType .get ([], tensor_type .element_type )
336+ fill_value_ir_value = stablehlo .reshape (scalar_type , fill_value_ir_value )
337+
338+ # Determine the target element type from the node's output definition.
339+ result_types = lowering_utils .node_meta_to_ir_types (lctx .node )
340+ if not result_types or not isinstance (result_types [0 ], ir .RankedTensorType ):
341+ raise ValueError (
342+ "tfl.fill: Unable to determine result tensor type or result is not a"
343+ " ranked tensor."
344+ )
345+ target_element_type = result_types [0 ].element_type
346+
347+ # Ensure fill_value_ir_value is a RankedTensorType to access its properties.
348+ if not isinstance (fill_value_ir_value .type , ir .RankedTensorType ):
349+ raise TypeError (
350+ "tfl.fill: fill_value_ir_value expected to be RankedTensorType, got"
351+ f" { fill_value_ir_value .type } "
352+ )
353+
354+ current_fill_tensor_type = fill_value_ir_value .type
355+ current_element_type = current_fill_tensor_type .element_type
356+
357+ # If the element type of the (scalar) fill_value doesn't match the target
358+ # output element type, cast fill_value_ir_value to the target_element_type
359+ # while maintaining its current shape (which should be scalar).
360+ if current_element_type != target_element_type :
361+ cast_to_type = ir .RankedTensorType .get (
362+ current_fill_tensor_type .shape , target_element_type
363+ )
364+ fill_value_ir_value = stablehlo .convert (cast_to_type , fill_value_ir_value )
365+
366+ return _ir_operation (
367+ "tfl.fill" ,
368+ results = result_types ,
369+ operands = [dims_ir_value , fill_value_ir_value ],
370+ )
371+
372+
320373@lower (torch .ops .tfl .reshape .default )
321374def _tfl_reshape_lowering (
322375 lctx : LoweringContext ,
323376 x : ir .Value ,
324377 shape : Sequence [int | ir .Value ],
325378) -> ir .Value :
326- # Check if all elements in the shape sequence are integers.
327- if not shape or all (isinstance (dim , int ) for dim in shape ):
328- # If all are integers, create a constant numpy array.
329- # Assuming int32 is the required type for TFLite shape tensors.
330- shape_ir_value = lowering_utils .numpy_array_constant (
331- np .array (shape , dtype = np .int32 )
332- )
333- else :
334- # Handle mixed int and ir.Value shape sequence
335- processed_dims = []
336- for dim in shape :
337- if isinstance (dim , int ):
338- # Convert int to a constant 1D tensor
339- shape_ir_value = lowering_utils .numpy_array_constant (
340- np .array ([dim ], dtype = np .int32 )
341- )
342- processed_dims .append (shape_ir_value )
343- else :
344- assert isinstance (dim , ir .Value )
345- # Convert ir.Value to a constant 1D tensor
346- new_type = ir .RankedTensorType .get ([1 ], dim .type .element_type )
347- reshape_dim = stablehlo .reshape (new_type , dim )
348- processed_dims .append (reshape_dim )
349-
350- shape_ir_value = stablehlo .concatenate (
351- processed_dims ,
352- dimension = 0 ,
353- )
354-
355379 return _ir_operation (
356380 "tfl.reshape" ,
357381 results = lowering_utils .node_meta_to_ir_types (lctx .node ),
358- operands = [x , shape_ir_value ],
382+ operands = [x , lowering_utils . convert_shape_to_ir_value ( shape ) ],
359383 )
360384
361385
0 commit comments