Skip to content

Commit 4c7affb

Browse files
manuelcandalesfacebook-github-bot
authored andcommitted
SpecDB: Add spec: scatter.value
Reviewed By: JacobSzwejbka Differential Revision: D61874510 fbshipit-source-id: 11622265cc0fda1b75104697b7cc4443fe6aeca5
1 parent b8d1403 commit 4c7affb

File tree

1 file changed

+72
-0
lines changed

1 file changed

+72
-0
lines changed

specdb/db.py

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3478,6 +3478,78 @@
34783478
],
34793479
outspec=[OutArg(ArgType.Tensor)],
34803480
),
3481+
Spec( # TODO(mcandales): Calibrate.
3482+
op="scatter.value", # (Tensor self, int dim, Tensor index, Scalar value) -> Tensor
3483+
inspec=[
3484+
InPosArg(ArgType.Tensor, name="self"),
3485+
InPosArg(
3486+
ArgType.Dim,
3487+
name="dim",
3488+
deps=[0],
3489+
constraints=[
3490+
cp.Value.In(lambda deps: fn.dim_non_zero_size(deps[0])),
3491+
],
3492+
),
3493+
InPosArg(
3494+
ArgType.Tensor,
3495+
name="index",
3496+
deps=[0, 1],
3497+
# TODO(mcandales) Handle index.numel() == 0 case
3498+
constraints=[
3499+
cp.Dtype.Eq(lambda deps: torch.long),
3500+
cp.Rank.Eq(
3501+
lambda deps: deps[0].dim() if deps[0].dim() >= 2 else None
3502+
),
3503+
cp.Rank.In(
3504+
lambda deps: [0, 1] if deps[0].dim() in [0, 1] else None
3505+
),
3506+
cp.Size.Le(
3507+
lambda deps, r, d: (
3508+
fn.safe_size(deps[0], d)
3509+
if d != fn.normalize(deps[1], deps[0].dim())
3510+
else None
3511+
)
3512+
),
3513+
cp.Value.Ge(lambda deps, dtype, struct: 0),
3514+
cp.Value.Le(
3515+
lambda deps, dtype, struct: (
3516+
0
3517+
if deps[0].dim() == 0
3518+
else max(0, fn.safe_size(deps[0], deps[1]) - 1)
3519+
)
3520+
),
3521+
],
3522+
),
3523+
InPosArg(
3524+
ArgType.Scalar,
3525+
name="value",
3526+
deps=[0],
3527+
constraints=[
3528+
cp.Value.NotIn(
3529+
lambda deps, dtype: (
3530+
[float("-inf"), float("inf")]
3531+
if deps[0].dtype not in dt._floating
3532+
else None
3533+
)
3534+
),
3535+
cp.Value.Ge(
3536+
lambda deps, dtype: fn.dtype_lower_bound(deps[0].dtype)
3537+
),
3538+
cp.Value.Le(
3539+
lambda deps, dtype: fn.dtype_upper_bound(deps[0].dtype)
3540+
),
3541+
],
3542+
),
3543+
],
3544+
outspec=[
3545+
OutArg(
3546+
ArgType.Tensor,
3547+
constraints=[
3548+
cp.Dtype.Eq(lambda deps: deps[0].dtype),
3549+
],
3550+
),
3551+
],
3552+
),
34813553
Spec( # TODO(mcandales): Calibrate.
34823554
op="scatter_add.default", # (Tensor self, int dim, Tensor index, Tensor src) -> Tensor
34833555
inspec=[

0 commit comments

Comments
 (0)