Skip to content

Commit 0fbbfd2

Browse files
manuelcandalesfacebook-github-bot
authored andcommitted
Add spec: max_pool3d_with_indices
Reviewed By: zonglinpeng Differential Revision: D76528169 fbshipit-source-id: 0c7602f6c0d75930e1b7810bb4aa8fd8a3e9d054
1 parent 9b7ea9c commit 0fbbfd2

File tree

3 files changed

+65
-2
lines changed

3 files changed

+65
-2
lines changed

facto/specdb/__init__.py

Whitespace-only changes.

facto/specdb/db.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2500,6 +2500,69 @@
25002500
OutArg(ArgType.Tensor, name="indices"),
25012501
],
25022502
),
2503+
Spec( # TODO(mcandales): Calibrate.
2504+
op="max_pool3d_with_indices.default", # (Tensor self, int[3] kernel_size, int[3] stride=[], int[3] padding=0, int[3] dilation=1, bool ceil_mode=False) -> (Tensor, Tensor)
2505+
inspec=[
2506+
InPosArg( # self
2507+
ArgType.Tensor,
2508+
name="self",
2509+
deps=[1, 2, 3, 4, 5],
2510+
constraints=[
2511+
cp.Dtype.Ne(lambda deps: torch.bool),
2512+
cp.Rank.In(lambda deps: [4, 5]),
2513+
# cp.Size.Ge(lambda deps, r, d: 0 if d == 0 and r == 5 else 1),
2514+
cp.Size.Ge(
2515+
lambda deps, r, d: fn.pool_input_size_min(
2516+
3, deps[0], deps[1], deps[2], deps[3], deps[4], r, d
2517+
)
2518+
),
2519+
],
2520+
),
2521+
InPosArg( # kernel_size
2522+
ArgType.LengthList,
2523+
name="kernel_size",
2524+
constraints=[
2525+
cp.Length.In(lambda deps: [1, 3]),
2526+
cp.Value.Ge(lambda deps, length, ix: 1),
2527+
],
2528+
),
2529+
InPosArg( # stride
2530+
ArgType.LengthList,
2531+
name="stride",
2532+
constraints=[
2533+
cp.Length.In(lambda deps: [0, 1, 3]),
2534+
cp.Value.Ge(lambda deps, length, ix: 1),
2535+
],
2536+
),
2537+
InPosArg( # padding
2538+
ArgType.LengthList,
2539+
name="padding",
2540+
deps=[1],
2541+
constraints=[
2542+
cp.Length.In(lambda deps: [1, 3]),
2543+
cp.Value.Ge(lambda deps, length, ix: 0),
2544+
cp.Value.Le(
2545+
lambda deps, length, ix: fn.pool_padding_max(
2546+
deps[0], length, ix
2547+
)
2548+
),
2549+
],
2550+
),
2551+
InPosArg( # dilation
2552+
ArgType.LengthList,
2553+
name="dilation",
2554+
constraints=[
2555+
cp.Length.In(lambda deps: [1, 3]),
2556+
cp.Value.Ge(lambda deps, length, ix: 1),
2557+
],
2558+
),
2559+
InPosArg(ArgType.Bool, name="ceil_mode"), # ceil_mode
2560+
],
2561+
outspec=[
2562+
OutArg(ArgType.Tensor, name="out"),
2563+
OutArg(ArgType.Tensor, name="indices"),
2564+
],
2565+
),
25032566
Spec(
25042567
op="maximum.default", # (Tensor self, Tensor other) -> Tensor
25052568
inspec=[

facto/specdb/function.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -459,8 +459,8 @@ def pool_input_size_min(
459459
kernel_ndim, kernel_size, stride, padding, dilation, ceil_mode, rank, dim
460460
):
461461
if dim == 0:
462-
return 0 if rank == 4 else 1
463-
if dim == 1 and rank == 4:
462+
return 0 if rank == (kernel_ndim + 2) else 1
463+
if dim == 1 and rank == (kernel_ndim + 2):
464464
return 1
465465

466466
kdim = dim - (rank - kernel_ndim)

0 commit comments

Comments
 (0)