|
3478 | 3478 | ],
|
3479 | 3479 | outspec=[OutArg(ArgType.Tensor)],
|
3480 | 3480 | ),
|
| 3481 | + Spec( # TODO(mcandales): Calibrate. |
| 3482 | + op="scatter.value", # (Tensor self, int dim, Tensor index, Scalar value) -> Tensor |
| 3483 | + inspec=[ |
| 3484 | + InPosArg(ArgType.Tensor, name="self"), |
| 3485 | + InPosArg( |
| 3486 | + ArgType.Dim, |
| 3487 | + name="dim", |
| 3488 | + deps=[0], |
| 3489 | + constraints=[ |
| 3490 | + cp.Value.In(lambda deps: fn.dim_non_zero_size(deps[0])), |
| 3491 | + ], |
| 3492 | + ), |
| 3493 | + InPosArg( |
| 3494 | + ArgType.Tensor, |
| 3495 | + name="index", |
| 3496 | + deps=[0, 1], |
| 3497 | + # TODO(mcandales) Handle index.numel() == 0 case |
| 3498 | + constraints=[ |
| 3499 | + cp.Dtype.Eq(lambda deps: torch.long), |
| 3500 | + cp.Rank.Eq( |
| 3501 | + lambda deps: deps[0].dim() if deps[0].dim() >= 2 else None |
| 3502 | + ), |
| 3503 | + cp.Rank.In( |
| 3504 | + lambda deps: [0, 1] if deps[0].dim() in [0, 1] else None |
| 3505 | + ), |
| 3506 | + cp.Size.Le( |
| 3507 | + lambda deps, r, d: ( |
| 3508 | + fn.safe_size(deps[0], d) |
| 3509 | + if d != fn.normalize(deps[1], deps[0].dim()) |
| 3510 | + else None |
| 3511 | + ) |
| 3512 | + ), |
| 3513 | + cp.Value.Ge(lambda deps, dtype, struct: 0), |
| 3514 | + cp.Value.Le( |
| 3515 | + lambda deps, dtype, struct: ( |
| 3516 | + 0 |
| 3517 | + if deps[0].dim() == 0 |
| 3518 | + else max(0, fn.safe_size(deps[0], deps[1]) - 1) |
| 3519 | + ) |
| 3520 | + ), |
| 3521 | + ], |
| 3522 | + ), |
| 3523 | + InPosArg( |
| 3524 | + ArgType.Scalar, |
| 3525 | + name="value", |
| 3526 | + deps=[0], |
| 3527 | + constraints=[ |
| 3528 | + cp.Value.NotIn( |
| 3529 | + lambda deps, dtype: ( |
| 3530 | + [float("-inf"), float("inf")] |
| 3531 | + if deps[0].dtype not in dt._floating |
| 3532 | + else None |
| 3533 | + ) |
| 3534 | + ), |
| 3535 | + cp.Value.Ge( |
| 3536 | + lambda deps, dtype: fn.dtype_lower_bound(deps[0].dtype) |
| 3537 | + ), |
| 3538 | + cp.Value.Le( |
| 3539 | + lambda deps, dtype: fn.dtype_upper_bound(deps[0].dtype) |
| 3540 | + ), |
| 3541 | + ], |
| 3542 | + ), |
| 3543 | + ], |
| 3544 | + outspec=[ |
| 3545 | + OutArg( |
| 3546 | + ArgType.Tensor, |
| 3547 | + constraints=[ |
| 3548 | + cp.Dtype.Eq(lambda deps: deps[0].dtype), |
| 3549 | + ], |
| 3550 | + ), |
| 3551 | + ], |
| 3552 | + ), |
3481 | 3553 | Spec( # TODO(mcandales): Calibrate.
|
3482 | 3554 | op="scatter_add.default", # (Tensor self, int dim, Tensor index, Tensor src) -> Tensor
|
3483 | 3555 | inspec=[
|
|
0 commit comments