Skip to content

Commit c096ce1

Browse files
[API Compatibility] Enhance several tensor creation methods (#74477)
* fix index_elemwentwise_get_gard bug slice-check * enhance Tensor creation methods * add static test * fix UT * fix date * refine code * fix * fix UT * fix * fix BatchNormDoubleGradKernel * restore code * fix * fix * fix * fix for review * restore requires_grad setting * fix name * use full instead of fill_constant * refine device * fix * fix string device --------- Co-authored-by: zhanghonggeng <[email protected]>
1 parent b9d9ef7 commit c096ce1

File tree

6 files changed

+567
-108
lines changed

6 files changed

+567
-108
lines changed

python/paddle/device/__init__.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -217,7 +217,10 @@ def get_cudnn_version() -> int | None:
217217
return _cudnn_version
218218

219219

220-
def _convert_to_place(device: str) -> PlaceLike:
220+
def _convert_to_place(device: PlaceLike) -> PlaceLike:
221+
if not isinstance(device, str):
222+
return device # return directly if not a string
223+
221224
lower_device = device.lower()
222225
if device in core.get_all_custom_device_type():
223226
selected_devices = os.getenv(f"FLAGS_selected_{device}s", "0").split(

0 commit comments

Comments
 (0)