Skip to content

Commit 33e6ea3

Browse files
manuelcandalesfacebook-github-bot
authored andcommitted
SpecDB: Add spec: maximum & minimum
Reviewed By: SS-JIA Differential Revision: D52927776 fbshipit-source-id: f99f9056d05ca1bb5854d0784c8397c45a4a6cb2
1 parent 9a31777 commit 33e6ea3

File tree

1 file changed

+40
-1
lines changed

1 file changed

+40
-1
lines changed

specdb/db.py

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2039,6 +2039,34 @@
20392039
OutArg(ArgType.Tensor, name="indices"),
20402040
],
20412041
),
2042+
Spec(
2043+
op="maximum.default", # (Tensor self, Tensor other) -> Tensor
2044+
inspec=[
2045+
InPosArg(ArgType.Tensor, name="self"),
2046+
InPosArg(
2047+
ArgType.Tensor,
2048+
name="other",
2049+
deps=[0],
2050+
constraints=[
2051+
cp.Size.In(
2052+
lambda deps, r, d: fn.broadcast_with(deps[0].shape, r, d)
2053+
),
2054+
],
2055+
),
2056+
],
2057+
outspec=[
2058+
OutArg(
2059+
ArgType.Tensor,
2060+
constraints=[
2061+
cp.Dtype.In(
2062+
lambda deps: dt.can_cast_from(
2063+
torch.promote_types(deps[0].dtype, deps[1].dtype)
2064+
)
2065+
),
2066+
],
2067+
)
2068+
],
2069+
),
20422070
Spec(
20432071
op="mean.dim", # (Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor
20442072
inspec=[
@@ -2102,7 +2130,18 @@
21022130
],
21032131
),
21042132
],
2105-
outspec=[OutArg(ArgType.Tensor)],
2133+
outspec=[
2134+
OutArg(
2135+
ArgType.Tensor,
2136+
constraints=[
2137+
cp.Dtype.In(
2138+
lambda deps: dt.can_cast_from(
2139+
torch.promote_types(deps[0].dtype, deps[1].dtype)
2140+
)
2141+
),
2142+
],
2143+
)
2144+
],
21062145
),
21072146
Spec(
21082147
op="mm.default", # (Tensor self, Tensor mat2) -> Tensor

0 commit comments

Comments
 (0)