Skip to content

Commit 8896942

Browse files
manuelcandalesfacebook-github-bot
authored andcommitted
SpecDB: Add spec: prod.default & prod.dim_int
Reviewed By: SS-JIA Differential Revision: D52927778 fbshipit-source-id: 804629abf8262cb5df06d902526173f03dddd6d5
1 parent 33e6ea3 commit 8896942

File tree

1 file changed

+58
-0
lines changed

1 file changed

+58
-0
lines changed

specdb/db.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2410,6 +2410,64 @@
24102410
],
24112411
outspec=[OutArg(ArgType.Tensor)],
24122412
),
2413+
Spec(
2414+
op="prod.default", # (Tensor self, *, ScalarType? dtype=None) -> Tensor
2415+
inspec=[
2416+
InPosArg(ArgType.Tensor, name="self"),
2417+
InKwArg(ArgType.ScalarTypeOpt, name="dtype"),
2418+
],
2419+
outspec=[
2420+
OutArg(
2421+
ArgType.Tensor,
2422+
deps=[0, 1],
2423+
constraints=[
2424+
cp.Dtype.Eq(lambda deps: deps[1] if deps[1] is not None else None),
2425+
cp.Dtype.Eq(
2426+
lambda deps: torch.long
2427+
if deps[1] is None and deps[0].dtype in dt._int_and_bool
2428+
else None
2429+
),
2430+
cp.Dtype.Eq(
2431+
lambda deps: deps[0].dtype
2432+
if deps[1] is None and deps[0].dtype not in dt._int_and_bool
2433+
else None
2434+
),
2435+
],
2436+
),
2437+
],
2438+
),
2439+
Spec(
2440+
op="prod.dim_int", # (Tensor self, int dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor
2441+
inspec=[
2442+
InPosArg(ArgType.Tensor, name="self"),
2443+
InPosArg(
2444+
ArgType.Dim,
2445+
name="dim",
2446+
deps=[0],
2447+
constraints=DimDefault,
2448+
),
2449+
InPosArg(ArgType.Bool, name="keepdim"),
2450+
InKwArg(ArgType.ScalarTypeOpt, name="dtype"),
2451+
],
2452+
outspec=[
2453+
OutArg(
2454+
ArgType.Tensor,
2455+
constraints=[
2456+
cp.Dtype.Eq(lambda deps: deps[3] if deps[3] is not None else None),
2457+
cp.Dtype.Eq(
2458+
lambda deps: torch.long
2459+
if deps[3] is None and deps[0].dtype in dt._int_and_bool
2460+
else None
2461+
),
2462+
cp.Dtype.Eq(
2463+
lambda deps: deps[0].dtype
2464+
if deps[3] is None and deps[0].dtype not in dt._int_and_bool
2465+
else None
2466+
),
2467+
],
2468+
),
2469+
],
2470+
),
24132471
Spec(
24142472
op="reciprocal.default", # (Tensor self) -> Tensor
24152473
inspec=[

0 commit comments

Comments
 (0)