@@ -3055,3 +3055,93 @@ def _polar(abs_, angle):
3055
3055
result = backend .math ._get_complex_tensor_from_tuple ((real , imaginary ))
3056
3056
3057
3057
return result
3058
+
3059
+
3060
class Unfold(Operation):
    """Operation that extracts sliding local blocks from a 4-D input.

    The input shape is unpacked as `(N, C, H, W)` and the output spec is
    `(N, C * kH * kW, outH * outW)`.

    Args:
        kernel_size: int or pair of ints, size of the sliding window
            `(kH, kW)`. A single int is used for both dimensions.
        dilation: int or pair of ints, spacing between kernel points.
            Defaults to `1`.
        padding: int or pair of ints, padding applied to both sides of
            each spatial dimension. Defaults to `0`.
        stride: int or pair of ints, step of the sliding window.
            Defaults to `1`.
        name: Optional name forwarded to the base `Operation`.
    """

    def __init__(
        self, kernel_size, dilation=1, padding=0, stride=1, *, name=None
    ):
        super().__init__(name=name)
        self.kernel_size = kernel_size
        self.dilation = dilation
        self.padding = padding
        self.stride = stride

    def compute_output_spec(self, x):
        """Infer the output `KerasTensor` for a symbolic input `x`.

        Any output dimension that depends on an unknown (`None`) input
        dimension is reported as `None` instead of raising a `TypeError`
        from arithmetic on `None`.

        Args:
            x: Symbolic tensor whose shape unpacks as `(N, C, H, W)`.

        Returns:
            A `KerasTensor` of shape `(N, C * kH * kW, outH * outW)` with
            the same dtype as `x`.
        """
        N, C, H, W = x.shape

        def _pair(value):
            # A single int applies to both spatial dimensions.
            return (value, value) if isinstance(value, int) else value

        kH, kW = _pair(self.kernel_size)
        dH, dW = _pair(self.dilation)
        pH, pW = _pair(self.padding)
        sH, sW = _pair(self.stride)

        def out_size(L, k, d, p, s):
            # Unknown input extent -> unknown number of window positions.
            if L is None:
                return None
            return (L + 2 * p - d * (k - 1) - 1) // s + 1

        outH = out_size(H, kH, dH, pH, sH)
        outW = out_size(W, kW, dW, pW, sW)

        patch_dim = None if C is None else C * kH * kW
        if outH is None or outW is None:
            num_patches = None
        else:
            num_patches = outH * outW
        return KerasTensor(shape=(N, patch_dim, num_patches), dtype=x.dtype)

    def call(self, x):
        """Run the unfold computation on `x` with the stored parameters."""
        return _unfold(
            x, self.kernel_size, self.dilation, self.padding, self.stride
        )
3092
+
3093
+
3094
@keras_export(["keras.ops.unfold", "keras.ops.nn.unfold"])
def unfold(x, kernel_size, dilation=1, padding=0, stride=1):
    """Extract sliding local blocks from a 4-D input (batched image).

    This operation is known as **im2col** when used with convolution.
    It rearranges the image into overlapping or non-overlapping patches.
    Each patch is flattened along axis 1 of the result, and the last
    axis indexes the extracted patches.

    Args:
        x: A 4-D tensor of shape `(N, C, H, W)` (**channels-first** format).
        kernel_size: int or tuple of two ints, the size of the sliding window
            `(kH, kW)`. If a single int is given, it is used for both
            dimensions.
        dilation: int or tuple of two ints, the spacing between kernel points
            (a.k.a. **dilation** or **atrous** convolution). Default: 1.
        padding: int or tuple of two ints, the amount of zero-padding to apply
            to both spatial dimensions. Default: 0.
        stride: int or tuple of two ints, the step size of the sliding window.
            Default: 1.

    Returns:
        A 3-D tensor of shape `(N, C * kH * kW, L)` where
        `L = num_patches_H * num_patches_W` is the total number of patches
        extracted.

    Raises:
        ValueError: If `x` is not a 4-D tensor.

    Example:

    >>> x = keras.ops.ones((1, 2, 4, 4))
    >>> patches = keras.ops.unfold(x, kernel_size=2, stride=2)
    >>> patches.shape
    (1, 8, 4)

    """
    input_shape = x.shape
    if len(input_shape) != 4:
        raise ValueError(
            f"Input must be a 4D tensor. Received: input.shape={input_shape}"
        )
    # Symbolic inputs only need shape/dtype inference, not real computation.
    if any_symbolic_tensors((x,)):
        return Unfold(kernel_size, dilation, padding, stride).symbolic_call(x)
    return _unfold(x, kernel_size, dilation, padding, stride)
3137
+
3138
+
3139
def _unfold(x, kernel_size, dilation=1, padding=0, stride=1):
    """Forward the unfold request to the active backend.

    All window parameters are passed through unchanged, by keyword, to
    `backend.nn.unfold`, and its result is returned as-is.
    """
    window_options = {
        "kernel_size": kernel_size,
        "dilation": dilation,
        "padding": padding,
        "stride": stride,
    }
    return backend.nn.unfold(x, **window_options)
0 commit comments