Skip to content

Commit 43d5bd6

Browse files
Echo-Nieluotao1
andauthored
【Hackathon 9th No.70】supplementary unit test for CPUPlatform and CUDAPlatform (#3580)
* 功能模块 CUDAPlatform、CPUPlatform 单测补充 * update the "is_cuda" to "is_cuda_and_available" * fix pre-commit --------- Co-authored-by: Tao Luo <[email protected]>
1 parent 72094d4 commit 43d5bd6

File tree

1 file changed

+86
-0
lines changed

1 file changed

+86
-0
lines changed

tests/platforms/test_platforms.py

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
"""
2+
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
"""
16+
17+
import unittest
18+
from unittest.mock import patch
19+
20+
from fastdeploy.platforms.base import _Backend
21+
from fastdeploy.platforms.cpu import CPUPlatform
22+
from fastdeploy.platforms.cuda import CUDAPlatform
23+
24+
25+
class TestCPUPlatform(unittest.TestCase):
26+
def setUp(self):
27+
self.platform = CPUPlatform()
28+
29+
@patch("paddle.device.get_device", return_value="cpu")
30+
def test_is_cpu_and_available(self, mock_get_device):
31+
"""
32+
Check hardware type (CPU) and availability
33+
"""
34+
self.assertTrue(self.platform.is_cpu())
35+
self.assertTrue(self.platform.available())
36+
37+
def test_attention_backend(self):
38+
"""CPUPlatform attention_backend should return empty string"""
39+
self.assertEqual(self.platform.get_attention_backend_cls(None), "")
40+
41+
42+
class TestCUDAPlatform(unittest.TestCase):
43+
def setUp(self):
44+
self.platform = CUDAPlatform()
45+
46+
@patch("paddle.is_compiled_with_cuda", return_value=True)
47+
@patch("paddle.device.get_device", return_value="cuda")
48+
@patch("paddle.static.cuda_places", return_value=[0])
49+
def test_is_cuda_and_available(self, mock_get_device, mock_is_cuda, mock_cuda_places):
50+
"""
51+
Check hardware type (CUDA) and availability
52+
"""
53+
self.assertTrue(self.platform.is_cuda())
54+
self.assertTrue(self.platform.available())
55+
56+
def test_attention_backend_valid(self):
57+
"""
58+
CUDAPlatform should return correct backend class name for valid backends
59+
"""
60+
self.assertIn(
61+
"PaddleNativeAttnBackend",
62+
self.platform.get_attention_backend_cls(_Backend.NATIVE_ATTN),
63+
)
64+
self.assertIn(
65+
"AppendAttentionBackend",
66+
self.platform.get_attention_backend_cls(_Backend.APPEND_ATTN),
67+
)
68+
self.assertIn(
69+
"MLAAttentionBackend",
70+
self.platform.get_attention_backend_cls(_Backend.MLA_ATTN),
71+
)
72+
self.assertIn(
73+
"FlashAttentionBackend",
74+
self.platform.get_attention_backend_cls(_Backend.FLASH_ATTN),
75+
)
76+
77+
def test_attention_backend_invalid(self):
78+
"""
79+
CUDAPlatform should raise ValueError for invalid backend
80+
"""
81+
with self.assertRaises(ValueError):
82+
self.platform.get_attention_backend_cls("INVALID_BACKEND")
83+
84+
85+
if __name__ == "__main__":
86+
unittest.main()

0 commit comments

Comments
 (0)