Skip to content

Commit fd95aba

Browse files
authored
[Compat] Add missing interfaces for PyTorch compat (PaddlePaddle#75874)
1 parent 6ca20eb commit fd95aba

File tree

6 files changed

+92
-3
lines changed

6 files changed

+92
-3
lines changed

paddle/phi/api/include/compat/ATen/cuda/CUDAContext.h

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,27 @@
1212
// See the License for the specific language governing permissions and
1313
// limitations under the License.
1414

15+
// #The file has been adapted from pytorch project
16+
// #Licensed under BSD-style license -
17+
// https://github.com/pytorch/pytorch/blob/main/LICENSE
18+
1519
#pragma once
1620

1721
#include <ATen/cuda/Exceptions.h>
22+
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
1823
#include <c10/cuda/CUDAStream.h>
24+
#include <cuda_runtime_api.h>
25+
#include "paddle/phi/backends/gpu/gpu_info.h"
26+
27+
namespace at::cuda {
28+
cudaDeviceProp* getDeviceProperties(c10::DeviceIndex device) {
29+
return const_cast<cudaDeviceProp*>(
30+
&phi::backends::gpu::GetDeviceProperties(device));
31+
}
32+
33+
cudaDeviceProp* getCurrentDeviceProperties() {
34+
auto device = phi::backends::gpu::GetCurrentDeviceId();
35+
return getDeviceProperties(device);
36+
}
37+
} // namespace at::cuda
38+
#endif

paddle/phi/api/include/compat/c10/util/Exception.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,9 @@ namespace c10 {
4646
} \
4747
} while (false);
4848

49+
// Check for a given boolean condition.
50+
#define CHECK(condition) PD_CHECK(condition, "CHECK failed : ", #condition)
51+
4952
// TORCH_CHECK_OP macro definitions
5053
#define TORCH_CHECK_EQ(val1, val2) TORCH_CHECK_OP(val1, val2, ==)
5154
#define TORCH_CHECK_NE(val1, val2) TORCH_CHECK_OP(val1, val2, !=)
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
// #The file has been adapted from pytorch project
16+
// #Licensed under BSD-style license -
17+
// https://github.com/pytorch/pytorch/blob/main/LICENSE
18+
19+
#pragma once
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
// #The file has been adapted from pytorch project
16+
// #Licensed under BSD-style license -
17+
// https://github.com/pytorch/pytorch/blob/main/LICENSE
18+
19+
#pragma once
20+
#include <ATen/Device.h>
21+
#include <c10/util/Exception.h>
22+
#include <torch/types.h>
23+
24+
#if !defined(PADDLE_ON_INFERENCE) && !defined(PADDLE_NO_PYTHON)
25+
// Python bindings for the C++ frontend (includes Python.h)
26+
#include "paddle/utils/pybind.h"
27+
#endif
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
// #The file has been adapted from pytorch project
16+
// #Licensed under BSD-style license -
17+
// https://github.com/pytorch/pytorch/blob/main/LICENSE
18+
19+
#pragma once
20+
21+
#include <torch/all.h>
22+
#include <torch/python.h>

paddle/phi/api/include/compat/utils/macros.h

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,7 @@
1616

1717
namespace compat {
1818
#ifndef TORCH_EXTENSION_NAME
19-
#define _EXPAND(x) x
20-
#define TORCH_EXTENSION_NAME _EXPAND(PADDLE_EXTENSION_NAME)
21-
#undef _EXPAND
19+
#define TORCH_EXTENSION_NAME PADDLE_EXTENSION_NAME
2220
#endif
2321
#define UNSUPPORTED_FEATURE_IN_PADDLE(feature) \
2422
std::cerr << "Unsupported feature in Paddle: " << feature << std::endl;

0 commit comments

Comments
 (0)