Skip to content

Commit ce51d74

Browse files
authored
Add unfold op (#21685)
* add unfold op * fix. * fix error. * fix error. * fix error. * fix jax and tf backend ,add numpy implement. * fix document
1 parent 7c14344 commit ce51d74

File tree

11 files changed

+457
-0
lines changed

11 files changed

+457
-0
lines changed

keras/api/_tf_keras/keras/ops/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,7 @@
117117
from keras.src.ops.nn import squareplus as squareplus
118118
from keras.src.ops.nn import tanh_shrink as tanh_shrink
119119
from keras.src.ops.nn import threshold as threshold
120+
from keras.src.ops.nn import unfold as unfold
120121
from keras.src.ops.numpy import abs as abs
121122
from keras.src.ops.numpy import absolute as absolute
122123
from keras.src.ops.numpy import add as add

keras/api/_tf_keras/keras/ops/nn/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,3 +57,4 @@
5757
from keras.src.ops.nn import squareplus as squareplus
5858
from keras.src.ops.nn import tanh_shrink as tanh_shrink
5959
from keras.src.ops.nn import threshold as threshold
60+
from keras.src.ops.nn import unfold as unfold

keras/api/ops/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,7 @@
117117
from keras.src.ops.nn import squareplus as squareplus
118118
from keras.src.ops.nn import tanh_shrink as tanh_shrink
119119
from keras.src.ops.nn import threshold as threshold
120+
from keras.src.ops.nn import unfold as unfold
120121
from keras.src.ops.numpy import abs as abs
121122
from keras.src.ops.numpy import absolute as absolute
122123
from keras.src.ops.numpy import add as add

keras/api/ops/nn/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,3 +57,4 @@
5757
from keras.src.ops.nn import squareplus as squareplus
5858
from keras.src.ops.nn import tanh_shrink as tanh_shrink
5959
from keras.src.ops.nn import threshold as threshold
60+
from keras.src.ops.nn import unfold as unfold

keras/src/backend/jax/nn.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1413,3 +1413,46 @@ def _reshape_to_grouped(t):
14131413
)
14141414
encoded = vmapped_fn(query, key, value, bias, mask, is_causal, scale)
14151415
return jnp.reshape(encoded, output_shape)
1416+
1417+
1418+
def unfold(input, kernel_size, dilation=1, padding=0, stride=1):
    """JAX implementation of Unfold.

    Extract sliding local blocks from a **NCHW** batched image tensor.

    Args:
        input: 4-D tensor, shape (N, C, H, W) **required**.
        kernel_size: int or (kH, kW)
        dilation: int or (dH, dW), default 1
        padding: int or (pH, pW), default 0. Applied symmetrically to both
            sides of each spatial axis.
        stride: int or (sH, sW), default 1

    Returns:
        3-D tensor, shape (N, C*kH*kW, L), where L is the number of window
        positions (oH * oW).

    Raises:
        ValueError: if `input` is not 4-D.
    """

    def _pair(value):
        # A bare int applies to both spatial axes.
        return (value, value) if isinstance(value, int) else value

    k = _pair(kernel_size)
    d = _pair(dilation)
    p = _pair(padding)
    s = _pair(stride)

    input = jnp.asarray(input)
    if input.ndim != 4:
        raise ValueError(
            "`unfold` expects a 4-D (N, C, H, W) input. "
            f"Received: input.shape={input.shape}"
        )

    # ---- padding ----
    if any(amount > 0 for amount in p):
        input = jnp.pad(input, ((0, 0), (0, 0), (p[0], p[0]), (p[1], p[1])))

    patches = lax.conv_general_dilated_patches(
        input,
        filter_shape=k,
        window_strides=s,
        padding="VALID",  # the input has already been padded above
        rhs_dilation=d,
        dimension_numbers=("NCHW", "OIHW", "NCHW"),  # only 'NCHW' supported
    )  # shape: (N, C*kH*kW, oH, oW)

    # ---- reshape -> (N, C*kH*kW, L) ----
    batch, flat_patch, out_h, out_w = patches.shape
    return patches.reshape(batch, flat_patch, out_h * out_w)

keras/src/backend/numpy/nn.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1176,3 +1176,56 @@ def dot_product_attention(
11761176
return _dot_product_attention_xla(
11771177
query, key, value, bias, mask, is_causal, scale
11781178
)
1179+
1180+
1181+
def unfold(input, kernel_size, dilation=1, padding=0, stride=1):
    """NumPy implementation of Unfold.

    Extract sliding local blocks from a **NCHW** batched image tensor.

    Args:
        input: 4-D tensor, shape (N, C, H, W) **required**.
        kernel_size: int or (kH, kW)
        dilation: int or (dH, dW), default 1
        padding: int or (pH, pW), default 0. Zero padding applied
            symmetrically to both sides of each spatial axis.
        stride: int or (sH, sW), default 1

    Returns:
        3-D tensor, shape (N, C*kH*kW, L), where L = oH * oW is the number
        of window positions. The dtype matches `input`.

    Raises:
        ValueError: if `input` is not 4-D, or if the (dilated) window does
            not fit inside the padded input at least once.
    """

    def _pair(value):
        # A bare int applies to both spatial axes.
        return (value, value) if isinstance(value, int) else value

    k = _pair(kernel_size)
    d = _pair(dilation)
    p = _pair(padding)
    s = _pair(stride)

    input = np.asarray(input)
    if input.ndim != 4:
        raise ValueError(
            "`unfold` expects a 4-D (N, C, H, W) input. "
            f"Received: input.shape={input.shape}"
        )
    N, C = input.shape[:2]

    # ---- padding ----
    if any(amount > 0 for amount in p):
        input = np.pad(
            input, ((0, 0), (0, 0), (p[0], p[0]), (p[1], p[1])), mode="constant"
        )

    # ---- spatial size ----
    oH = (input.shape[2] - (k[0] - 1) * d[0] - 1) // s[0] + 1
    oW = (input.shape[3] - (k[1] - 1) * d[1] - 1) // s[1] + 1
    if oH < 1 or oW < 1:
        # Without this check a too-large window surfaces either as a silent
        # empty result or as an opaque "negative dimensions" error.
        raise ValueError(
            "The sliding window does not fit in the padded input. "
            f"Padded spatial size={tuple(input.shape[2:])}, "
            f"kernel_size={tuple(k)}, dilation={tuple(d)}, "
            f"computed output size=({oH}, {oW})."
        )

    # Top-left corner of every window, flattened in row-major (H, W) order.
    rows, cols = np.meshgrid(
        np.arange(oH) * s[0], np.arange(oW) * s[1], indexing="ij"
    )  # each of shape (oH, oW)
    rows = rows.reshape(-1)
    cols = cols.reshape(-1)

    # ---- flatten patches ----
    patches = np.empty((N, C, k[0], k[1], oH * oW), dtype=input.dtype)
    for idx in range(k[0]):
        for jdx in range(k[1]):
            # Gather kernel tap (idx, jdx) for all window positions at once.
            patches[:, :, idx, jdx, :] = input[
                :, :, rows + idx * d[0], cols + jdx * d[1]
            ]

    # ---- reshape -> (N, C*kH*kW, L) ----
    # Explicit sizes (rather than -1) keep the reshape well-defined when the
    # batch or channel dimension is empty.
    return patches.reshape(N, C * k[0] * k[1], oH * oW)

keras/src/backend/openvino/nn.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -502,3 +502,7 @@ def dot_product_attention(
502502
raise NotImplementedError(
503503
"`dot_product_attention` is not supported with openvino backend"
504504
)
505+
506+
507+
def unfold(input, kernel_size, dilation=1, padding=0, stride=1):
    """Placeholder for Unfold on the OpenVINO backend.

    Accepts the same arguments as the other backends' `unfold` but always
    raises, since no implementation is provided here.

    Raises:
        NotImplementedError: always.
    """
    message = "`unfold` is not supported with openvino backend"
    raise NotImplementedError(message)

keras/src/backend/tensorflow/nn.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1077,3 +1077,50 @@ def dot_product_attention(
10771077
return _dot_product_attention_xla(
10781078
query, key, value, bias, mask, is_causal, scale
10791079
)
1080+
1081+
1082+
def unfold(input, kernel_size, dilation=1, padding=0, stride=1):
    """Tensorflow implementation of Unfold.

    Extract sliding local blocks from a **NCHW** batched image tensor.

    Args:
        input: 4-D tensor, shape (N, C, H, W) **required**. The batch and
            spatial dimensions may be statically unknown (`None`).
        kernel_size: int or (kH, kW)
        dilation: int or (dH, dW), default 1
        padding: int or (pH, pW), default 0. Applied symmetrically to both
            sides of each spatial axis.
        stride: int or (sH, sW), default 1

    Returns:
        3-D tensor, shape (N, C*kH*kW, L)

    Raises:
        ValueError: if `input` is not 4-D.
    """

    def _pair(value):
        # A bare int applies to both spatial axes.
        return (value, value) if isinstance(value, int) else value

    def _dim(tensor, axis):
        # Prefer the static size; fall back to the runtime size when the
        # dimension is unknown at trace time (e.g. a dynamic batch).
        static = tensor.shape[axis]
        return static if static is not None else tf.shape(tensor)[axis]

    k = _pair(kernel_size)
    d = _pair(dilation)
    p = _pair(padding)
    s = _pair(stride)

    input = tf.convert_to_tensor(input)
    if input.shape.rank != 4:
        raise ValueError(
            "`unfold` expects a 4-D (N, C, H, W) input. "
            f"Received: input.shape={input.shape}"
        )
    C = _dim(input, 1)

    # ---- padding ----
    if any(amount > 0 for amount in p):
        input = tf.pad(input, [[0, 0], [0, 0], [p[0], p[0]], [p[1], p[1]]])
    x = tf.transpose(input, [0, 2, 3, 1])  # (N, H, W, C)
    patches = tf.image.extract_patches(
        images=x,
        sizes=[1, k[0], k[1], 1],
        strides=[1, s[0], s[1], 1],
        rates=[1, d[0], d[1], 1],
        padding="VALID",  # the input has already been padded above
    )  # (N, nH, nW, kH*kW*C)

    N = _dim(patches, 0)
    nH = _dim(patches, 1)
    nW = _dim(patches, 2)
    patches = tf.reshape(
        patches, [N, nH, nW, k[0], k[1], C]
    )  # (N, nH, nW, kH, kW, C)
    patches = tf.transpose(
        patches, [0, 5, 3, 4, 1, 2]
    )  # (N, C, kH, kW, nH, nW)
    return tf.reshape(patches, [N, C * k[0] * k[1], nH * nW])

keras/src/backend/torch/nn.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1092,3 +1092,26 @@ def dot_product_attention(
10921092
scale=scale,
10931093
)
10941094
return torch.transpose(attention_output, axis1, axis0)
1095+
1096+
1097+
def unfold(input, kernel_size, dilation=1, padding=0, stride=1):
    """Native PyTorch implementation of Unfold.

    Extract sliding local blocks from a **NCHW** batched image tensor by
    forwarding every argument unchanged to `tnn.unfold`.

    Args:
        input: 4-D tensor, shape (N, C, H, W) **required**.
        kernel_size: int or (kH, kW)
        dilation: int or (dH, dW), default 1
        padding: int or (pH, pW), default 0
        stride: int or (sH, sW), default 1

    Returns:
        3-D tensor, shape (N, C*kH*kW, L)
    """
    window_options = {
        "kernel_size": kernel_size,
        "dilation": dilation,
        "padding": padding,
        "stride": stride,
    }
    return tnn.unfold(input, **window_options)

keras/src/ops/nn.py

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3055,3 +3055,93 @@ def _polar(abs_, angle):
30553055
result = backend.math._get_complex_tensor_from_tuple((real, imaginary))
30563056

30573057
return result
3058+
3059+
3060+
class Unfold(Operation):
    """Operation that extracts sliding local blocks from an NCHW tensor.

    Stores the window configuration and forwards to `_unfold` when called.

    Args:
        kernel_size: int or (kH, kW), size of the sliding window.
        dilation: int or (dH, dW), spacing between kernel points.
        padding: int or (pH, pW), padding applied to both sides of each
            spatial axis.
        stride: int or (sH, sW), step of the sliding window.
        name: optional name passed to the base `Operation`.
    """

    def __init__(
        self, kernel_size, dilation=1, padding=0, stride=1, *, name=None
    ):
        super().__init__(name=name)
        self.kernel_size = kernel_size
        self.dilation = dilation
        self.padding = padding
        self.stride = stride

    def compute_output_spec(self, x):
        """Return the symbolic output of shape `(N, C*kH*kW, L)`.

        Any dimension that depends on an unknown (`None`) input dimension
        is reported as `None` instead of raising on `None` arithmetic.
        """
        N, C, H, W = x.shape

        def _pair(value):
            # A bare int applies to both spatial axes.
            return (value, value) if isinstance(value, int) else value

        kH, kW = _pair(self.kernel_size)
        dH, dW = _pair(self.dilation)
        pH, pW = _pair(self.padding)
        sH, sW = _pair(self.stride)

        def out_size(L, k, d, p, s):
            # Number of window positions along one spatial axis.
            if L is None:
                return None
            return (L + 2 * p - d * (k - 1) - 1) // s + 1

        outH = out_size(H, kH, dH, pH, sH)
        outW = out_size(W, kW, dW, pW, sW)

        flat_patch = None if C is None else C * kH * kW
        num_patches = None if outH is None or outW is None else outH * outW
        return KerasTensor(shape=(N, flat_patch, num_patches), dtype=x.dtype)

    def call(self, x):
        return _unfold(
            x, self.kernel_size, self.dilation, self.padding, self.stride
        )
3092+
3093+
3094+
@keras_export(["keras.ops.unfold", "keras.ops.nn.unfold"])
def unfold(x, kernel_size, dilation=1, padding=0, stride=1):
    """Extract sliding local blocks from a 4-D input (batched image).

    This operation is known as **im2col** when used with convolution.
    It rearranges the image into overlapping or non-overlapping patches
    and returns a tensor whose second axis contains the flattened patches
    and whose last axis indexes the patch locations.

    Args:
        x: A 4-D tensor of shape `(N, C, H, W)` (**channels-first** format).
        kernel_size: int or tuple of two ints, the size of the sliding window
            `(kH, kW)`. If a single int is given, it is used for both
            dimensions.
        dilation: int or tuple of two ints, the spacing between kernel points
            (a.k.a. **dilation** or **atrous** convolution). Default: 1.
        padding: int or tuple of two ints, the amount of zero-padding to apply
            to both spatial dimensions. Default: 0.
        stride: int or tuple of two ints, the step size of the sliding window.
            Default: 1.

    Returns:
        A 3-D tensor of shape `(N, C * kH * kW, L)` where
        `L = num_patches_H * num_patches_W` is the total number of patches
        extracted.

    Raises:
        ValueError: if `x` is not a 4-D tensor.

    Example:

    >>> x = keras.ops.ones((1, 2, 4, 4))
    >>> patches = keras.ops.unfold(x, kernel_size=2, stride=2)
    >>> patches.shape
    (1, 8, 4)

    """
    is_symbolic = any_symbolic_tensors((x,))
    if not is_symbolic:
        # Eager inputs may be array-likes (lists, NumPy arrays); convert them
        # so `.shape` exists and every backend receives a native tensor.
        x = backend.convert_to_tensor(x)
    input_shape = x.shape
    ndims = len(input_shape)
    if ndims != 4:
        raise ValueError(
            f"Input must be a 4D tensor. Received: input.shape={input_shape}"
        )
    if is_symbolic:
        return Unfold(kernel_size, dilation, padding, stride).symbolic_call(x)
    return _unfold(x, kernel_size, dilation, padding, stride)
3137+
3138+
3139+
def _unfold(x, kernel_size, dilation=1, padding=0, stride=1):
    """Dispatch unfold to the active backend's `nn.unfold`.

    All window parameters are forwarded unchanged as keyword arguments.
    """
    window_options = {
        "kernel_size": kernel_size,
        "dilation": dilation,
        "padding": padding,
        "stride": stride,
    }
    return backend.nn.unfold(x, **window_options)

0 commit comments

Comments
 (0)