Skip to content

Commit b8d1403

Browse files
manuelcandalesfacebook-github-bot
authored andcommitted
SpecDB: Add spec: gather
Reviewed By: JacobSzwejbka Differential Revision: D61822096 fbshipit-source-id: ad9fc4129beaa2a98c9a61d73704f807b5e5a939
1 parent a7036d6 commit b8d1403

File tree

1 file changed

+53
-0
lines changed

1 file changed

+53
-0
lines changed

specdb/db.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1943,6 +1943,59 @@
19431943
OutArg(ArgType.Tensor),
19441944
],
19451945
),
1946+
Spec( # TODO(mcandales): Calibrate.
1947+
op="gather.default", # (Tensor self, int dim, Tensor index, *, bool sparse_grad=False) -> Tensor
1948+
inspec=[
1949+
InPosArg(ArgType.Tensor, name="self"),
1950+
InPosArg(
1951+
ArgType.Dim,
1952+
name="dim",
1953+
deps=[0],
1954+
constraints=[
1955+
cp.Value.In(lambda deps: fn.dim_non_zero_size(deps[0])),
1956+
],
1957+
),
1958+
InPosArg(
1959+
ArgType.Tensor,
1960+
name="index",
1961+
deps=[0, 1],
1962+
# TODO(mcandales) Handle index.numel() == 0 case
1963+
constraints=[
1964+
cp.Dtype.Eq(lambda deps: torch.long),
1965+
cp.Rank.Eq(
1966+
lambda deps: deps[0].dim() if deps[0].dim() >= 2 else None
1967+
),
1968+
cp.Rank.In(
1969+
lambda deps: [0, 1] if deps[0].dim() in [0, 1] else None
1970+
),
1971+
cp.Size.Le(
1972+
lambda deps, r, d: (
1973+
fn.safe_size(deps[0], d)
1974+
if d != fn.normalize(deps[1], deps[0].dim())
1975+
else None
1976+
)
1977+
),
1978+
cp.Value.Ge(lambda deps, dtype, struct: 0),
1979+
cp.Value.Le(
1980+
lambda deps, dtype, struct: (
1981+
0
1982+
if deps[0].dim() == 0
1983+
else max(0, fn.safe_size(deps[0], deps[1]) - 1)
1984+
)
1985+
),
1986+
],
1987+
),
1988+
InKwArg(ArgType.Bool, name="sparse_grad"),
1989+
],
1990+
outspec=[
1991+
OutArg(
1992+
ArgType.Tensor,
1993+
constraints=[
1994+
cp.Dtype.Eq(lambda deps: deps[0].dtype),
1995+
],
1996+
),
1997+
],
1998+
),
19461999
Spec(
19472000
op="ge.Scalar", # (Tensor self, Scalar other) -> Tensor
19482001
inspec=[

0 commit comments

Comments
 (0)