Skip to content

Commit 2d80827

Browse files
Fixed op supported dtypes for histc (#1794)
It is UT error. #1791 Add dtypes supported on XPU device for histc op in opinfo class. --------- Signed-off-by: Cheng, Penghui <[email protected]> Co-authored-by: Zhong, Ruijie <[email protected]>
1 parent bb34de0 commit 2d80827

File tree

2 files changed

+11
-2
lines changed

2 files changed

+11
-2
lines changed

.github/workflows/pull.yml

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -81,8 +81,7 @@ jobs:
8181
- name: Check Label
8282
id: check-label
8383
run: |
84-
has_label=$(echo '${{ toJSON(github.event.pull_request.labels) }}' | jq 'any(.name == "windows_ci")')
85-
echo "has_label=$has_label" |tee -a "${GITHUB_OUTPUT}"
84+
echo "has_label=${{ contains(github.event.pull_request.labels.*.name, 'windows_ci') }}" >> $GITHUB_OUTPUT
8685
- name: Check PR infos
8786
id: check-pr-desc
8887
run: |

test/xpu/xpu_test_utils.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -325,6 +325,10 @@
325325
"histogramdd",
326326
]
327327

328+
_ops_dtype_different_cuda_support = {
329+
"histc": {"forward": {torch.bfloat16, torch.float16}},
330+
}
331+
328332
# some case fail in cuda becasue of cuda's bug, so cuda set xfail in opdb
329333
# but xpu can pass these case, and assert 'unexpected success'
330334
# the list will pass these case.
@@ -916,6 +920,12 @@ def align_supported_dtypes(self, db):
916920
backward_dtypes.add(bfloat16)
917921
opinfo.backward_dtypes = tuple(backward_dtypes)
918922

923+
if opinfo.name in _ops_dtype_different_cuda_support:
924+
if "forward" in _ops_dtype_different_cuda_support[opinfo.name]:
925+
opinfo.dtypesIfXPU.update(
926+
_ops_dtype_different_cuda_support[opinfo.name]["forward"]
927+
)
928+
919929
if "has_fp64=0" in str(torch.xpu.get_device_properties(0)):
920930
fp64_dtypes = [
921931
torch.float64,

0 commit comments

Comments
 (0)