Skip to content

Commit 21c11da

Browse files
hipuddingpytorchmergebot
authored andcommitted
Improve OpenReg test coverage (pytorch#167819)
- add failure-path tests for device, stream, memory, event APIs - cover async memcpy, pointer attributes, event timing, addTask errors - verified via cmake --build build && ctest --test-dir build Pull Request resolved: pytorch#167819 Approved by: https://github.com/fffrog
1 parent 7ffa511 commit 21c11da

File tree

4 files changed

+238
-0
lines changed

4 files changed

+238
-0
lines changed

test/cpp_extensions/open_registration_extension/torch_openreg/third_party/openreg/tests/device_tests.cpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,22 @@ TEST_F(DeviceTest, GetDeviceCountValid) {
1616
EXPECT_EQ(count, 2);
1717
}
1818

19+
TEST_F(DeviceTest, GetDeviceCountNullptr) {
20+
// orGetDeviceCount should reject null output pointers.
21+
EXPECT_EQ(orGetDeviceCount(nullptr), orErrorUnknown);
22+
}
23+
1924
TEST_F(DeviceTest, GetDeviceValid) {
2025
int device = -1;
2126
EXPECT_EQ(orGetDevice(&device), orSuccess);
2227
EXPECT_EQ(device, 0);
2328
}
2429

30+
TEST_F(DeviceTest, GetDeviceNullptr) {
31+
// Defensive path: null output pointer must return an error.
32+
EXPECT_EQ(orGetDevice(nullptr), orErrorUnknown);
33+
}
34+
2535
TEST_F(DeviceTest, SetDeviceValid) {
2636
EXPECT_EQ(orSetDevice(1), orSuccess);
2737

@@ -38,4 +48,9 @@ TEST_F(DeviceTest, SetDeviceInvalidNegative) {
3848
EXPECT_EQ(orSetDevice(-1), orErrorUnknown);
3949
}
4050

51+
TEST_F(DeviceTest, SetDeviceInvalidTooLarge) {
52+
// Device indices are 0-based and strictly less than DEVICE_COUNT (2).
53+
EXPECT_EQ(orSetDevice(2), orErrorUnknown);
54+
}
55+
4156
} // namespace

test/cpp_extensions/open_registration_extension/torch_openreg/third_party/openreg/tests/event_tests.cpp

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,13 @@ TEST_F(EventTest, EventCreateWithFlagsTiming) {
2929
EXPECT_EQ(orEventDestroy(event), orSuccess);
3030
}
3131

32+
TEST_F(EventTest, EventCreationNullptr) {
33+
// Creation APIs must fail fast on null handles to mirror CUDA semantics.
34+
EXPECT_EQ(orEventCreate(nullptr), orErrorUnknown);
35+
EXPECT_EQ(
36+
orEventCreateWithFlags(nullptr, orEventEnableTiming), orErrorUnknown);
37+
}
38+
3239
TEST_F(EventTest, EventRecordAndSynchronize) {
3340
orStream_t stream = nullptr;
3441
EXPECT_EQ(orStreamCreate(&stream), orSuccess);
@@ -44,6 +51,23 @@ TEST_F(EventTest, EventRecordAndSynchronize) {
4451
EXPECT_EQ(orStreamDestroy(stream), orSuccess);
4552
}
4653

54+
TEST_F(EventTest, EventRecordInvalidArgs) {
55+
orEvent_t event = nullptr;
56+
EXPECT_EQ(orEventCreate(&event), orSuccess);
57+
58+
orStream_t stream = nullptr;
59+
EXPECT_EQ(orStreamCreate(&stream), orSuccess);
60+
61+
// Record/sync/destroy should validate both stream and event pointers.
62+
EXPECT_EQ(orEventRecord(nullptr, stream), orErrorUnknown);
63+
EXPECT_EQ(orEventRecord(event, nullptr), orErrorUnknown);
64+
EXPECT_EQ(orEventSynchronize(nullptr), orErrorUnknown);
65+
EXPECT_EQ(orEventDestroy(nullptr), orErrorUnknown);
66+
67+
EXPECT_EQ(orEventDestroy(event), orSuccess);
68+
EXPECT_EQ(orStreamDestroy(stream), orSuccess);
69+
}
70+
4771
TEST_F(EventTest, EventElapsedTime) {
4872
orStream_t stream = nullptr;
4973
EXPECT_EQ(orStreamCreate(&stream), orSuccess);
@@ -70,6 +94,60 @@ TEST_F(EventTest, EventElapsedTime) {
7094
EXPECT_EQ(orEventDestroy(end), orSuccess);
7195
}
7296

97+
// TODO: recording events to a stream is not allowed
98+
// if the stream and the event are not on the same device
99+
// Uncomment this test case after the issue is fixed.
100+
// see #167819
101+
TEST_F(EventTest, DISABLED_EventElapsedTimeDifferentDevicesFails) {
102+
orStream_t stream = nullptr;
103+
EXPECT_EQ(orStreamCreate(&stream), orSuccess);
104+
105+
orEvent_t start = nullptr;
106+
orEvent_t end = nullptr;
107+
EXPECT_EQ(orEventCreateWithFlags(&start, orEventEnableTiming), orSuccess);
108+
109+
EXPECT_EQ(orEventRecord(start, stream), orSuccess);
110+
111+
// Switch device before creating the end event to force a mismatch.
112+
EXPECT_EQ(orSetDevice(1), orSuccess);
113+
EXPECT_EQ(orEventCreateWithFlags(&end, orEventEnableTiming), orSuccess);
114+
EXPECT_EQ(orSetDevice(0), orSuccess);
115+
116+
EXPECT_EQ(orEventRecord(end, stream), orSuccess);
117+
EXPECT_EQ(orEventSynchronize(start), orSuccess);
118+
EXPECT_EQ(orEventSynchronize(end), orSuccess);
119+
120+
float elapsed_ms = 0.0f;
121+
EXPECT_EQ(orEventElapsedTime(&elapsed_ms, start, end), orErrorUnknown);
122+
123+
EXPECT_EQ(orEventDestroy(start), orSuccess);
124+
EXPECT_EQ(orEventDestroy(end), orSuccess);
125+
EXPECT_EQ(orStreamDestroy(stream), orSuccess);
126+
}
127+
128+
TEST_F(EventTest, EventElapsedTimeRequiresTimingFlag) {
129+
orStream_t stream = nullptr;
130+
EXPECT_EQ(orStreamCreate(&stream), orSuccess);
131+
132+
orEvent_t start = nullptr;
133+
orEvent_t end = nullptr;
134+
EXPECT_EQ(orEventCreate(&start), orSuccess);
135+
EXPECT_EQ(orEventCreate(&end), orSuccess);
136+
137+
EXPECT_EQ(orEventRecord(start, stream), orSuccess);
138+
EXPECT_EQ(orEventRecord(end, stream), orSuccess);
139+
EXPECT_EQ(orEventSynchronize(start), orSuccess);
140+
EXPECT_EQ(orEventSynchronize(end), orSuccess);
141+
142+
// Without timing-enabled events, querying elapsed time must fail.
143+
float elapsed_ms = 0.0f;
144+
EXPECT_EQ(orEventElapsedTime(&elapsed_ms, start, end), orErrorUnknown);
145+
146+
EXPECT_EQ(orEventDestroy(start), orSuccess);
147+
EXPECT_EQ(orEventDestroy(end), orSuccess);
148+
EXPECT_EQ(orStreamDestroy(stream), orSuccess);
149+
}
150+
73151
TEST_F(EventTest, StreamWaitEvent) {
74152
orStream_t stream = nullptr;
75153
EXPECT_EQ(orStreamCreate(&stream), orSuccess);
@@ -85,4 +163,19 @@ TEST_F(EventTest, StreamWaitEvent) {
85163
EXPECT_EQ(orStreamDestroy(stream), orSuccess);
86164
}
87165

166+
TEST_F(EventTest, StreamWaitEventInvalidArgs) {
167+
orStream_t stream = nullptr;
168+
EXPECT_EQ(orStreamCreate(&stream), orSuccess);
169+
170+
orEvent_t event = nullptr;
171+
EXPECT_EQ(orEventCreate(&event), orSuccess);
172+
173+
// Validate both stream and event inputs for wait calls.
174+
EXPECT_EQ(orStreamWaitEvent(nullptr, event, 0), orErrorUnknown);
175+
EXPECT_EQ(orStreamWaitEvent(stream, nullptr, 0), orErrorUnknown);
176+
177+
EXPECT_EQ(orEventDestroy(event), orSuccess);
178+
EXPECT_EQ(orStreamDestroy(stream), orSuccess);
179+
}
180+
88181
} // namespace

test/cpp_extensions/open_registration_extension/torch_openreg/third_party/openreg/tests/memory_tests.cpp

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,12 @@ TEST_F(MemoryTest, AllocateAndFreeHost) {
2626
EXPECT_EQ(orFreeHost(ptr), orSuccess);
2727
}
2828

29+
TEST_F(MemoryTest, FreeNullptrIsNoop) {
30+
// Freeing a nullptr should behave like CUDA: treated as a no-op success.
31+
EXPECT_EQ(orFree(nullptr), orSuccess);
32+
EXPECT_EQ(orFreeHost(nullptr), orSuccess);
33+
}
34+
2935
TEST_F(MemoryTest, AllocateNullptr) {
3036
EXPECT_EQ(orMalloc(nullptr, 4096), orErrorUnknown);
3137
EXPECT_EQ(orMallocHost(nullptr, 4096), orErrorUnknown);
@@ -86,6 +92,48 @@ TEST_F(MemoryTest, MemcpyInvalidKind) {
8692
EXPECT_EQ(orFree(dev_ptr), orSuccess);
8793
}
8894

95+
TEST_F(MemoryTest, MemcpyInvalidCombinations) {
96+
void *dev_src = nullptr, *dev_dst = nullptr;
97+
EXPECT_EQ(orMalloc(&dev_src, 8), orSuccess);
98+
EXPECT_EQ(orMalloc(&dev_dst, 8), orSuccess);
99+
100+
char host_buf[8] = {};
101+
102+
// Deliberately pass mismatched kinds to ensure validation coverage.
103+
EXPECT_EQ(
104+
orMemcpy(host_buf, dev_src, 4, orMemcpyHostToDevice), orErrorUnknown);
105+
EXPECT_EQ(
106+
orMemcpy(dev_dst, host_buf, 4, orMemcpyDeviceToHost), orErrorUnknown);
107+
EXPECT_EQ(
108+
orMemcpy(dev_dst, dev_src, 4, orMemcpyHostToDevice), orErrorUnknown);
109+
110+
EXPECT_EQ(orFree(dev_src), orSuccess);
111+
EXPECT_EQ(orFree(dev_dst), orSuccess);
112+
}
113+
114+
TEST_F(MemoryTest, MemcpyAsyncHostToDevice) {
115+
orStream_t stream = nullptr;
116+
EXPECT_EQ(orStreamCreate(&stream), orSuccess);
117+
118+
const char host_src[] = "async";
119+
char host_dst[6] = {};
120+
void* dev_ptr = nullptr;
121+
EXPECT_EQ(orMalloc(&dev_ptr, sizeof(host_src)), orSuccess);
122+
123+
// Async copies should complete once the stream is synchronized.
124+
EXPECT_EQ(
125+
orMemcpyAsync(dev_ptr, host_src, sizeof(host_src), orMemcpyHostToDevice, stream),
126+
orSuccess);
127+
EXPECT_EQ(orStreamSynchronize(stream), orSuccess);
128+
EXPECT_EQ(orMemcpy(
129+
host_dst, dev_ptr, sizeof(host_src), orMemcpyDeviceToHost),
130+
orSuccess);
131+
EXPECT_STREQ(host_dst, host_src);
132+
133+
EXPECT_EQ(orFree(dev_ptr), orSuccess);
134+
EXPECT_EQ(orStreamDestroy(stream), orSuccess);
135+
}
136+
89137
TEST_F(MemoryTest, PointerAttributes) {
90138
void* dev_ptr = nullptr;
91139
EXPECT_EQ(orMalloc(&dev_ptr, 32), orSuccess);
@@ -102,6 +150,14 @@ TEST_F(MemoryTest, PointerAttributes) {
102150
EXPECT_EQ(orFree(dev_ptr), orSuccess);
103151
}
104152

153+
TEST_F(MemoryTest, PointerAttributesInvalidArgs) {
154+
// Attribute queries must fail on null inputs to avoid dereferencing.
155+
char buffer[8] = {};
156+
orPointerAttributes attr{};
157+
EXPECT_EQ(orPointerGetAttributes(nullptr, buffer), orErrorUnknown);
158+
EXPECT_EQ(orPointerGetAttributes(&attr, nullptr), orErrorUnknown);
159+
}
160+
105161
TEST_F(MemoryTest, ProtectUnprotectDevice) {
106162
void* dev_ptr = nullptr;
107163
EXPECT_EQ(orMalloc(&dev_ptr, 64), orSuccess);
@@ -112,4 +168,24 @@ TEST_F(MemoryTest, ProtectUnprotectDevice) {
112168
EXPECT_EQ(orFree(dev_ptr), orSuccess);
113169
}
114170

171+
TEST_F(MemoryTest, ProtectReferenceCounting) {
172+
void* dev_ptr = nullptr;
173+
EXPECT_EQ(orMalloc(&dev_ptr, 64), orSuccess);
174+
175+
// Call unprotect/protect twice to exercise the refcount transitions.
176+
EXPECT_EQ(orMemoryUnprotect(dev_ptr), orSuccess);
177+
EXPECT_EQ(orMemoryUnprotect(dev_ptr), orSuccess);
178+
EXPECT_EQ(orMemoryProtect(dev_ptr), orSuccess);
179+
EXPECT_EQ(orMemoryProtect(dev_ptr), orSuccess);
180+
181+
EXPECT_EQ(orFree(dev_ptr), orSuccess);
182+
}
183+
184+
TEST_F(MemoryTest, DoubleFreeFails) {
185+
void* dev_ptr = nullptr;
186+
EXPECT_EQ(orMalloc(&dev_ptr, 32), orSuccess);
187+
EXPECT_EQ(orFree(dev_ptr), orSuccess);
188+
EXPECT_EQ(orFree(dev_ptr), orErrorUnknown);
189+
}
190+
115191
} // namespace

test/cpp_extensions/open_registration_extension/torch_openreg/third_party/openreg/tests/stream_tests.cpp

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,11 @@ TEST_F(StreamTest, StreamCreateAndDestroy) {
2121
EXPECT_EQ(orStreamDestroy(stream), orSuccess);
2222
}
2323

24+
TEST_F(StreamTest, StreamCreateNullptr) {
25+
// Creation API should reject null double-pointer inputs.
26+
EXPECT_EQ(orStreamCreate(nullptr), orErrorUnknown);
27+
}
28+
2429
TEST_F(StreamTest, StreamCreateWithInvalidPriority) {
2530
orStream_t stream = nullptr;
2631
int min_p, max_p;
@@ -30,6 +35,36 @@ TEST_F(StreamTest, StreamCreateWithInvalidPriority) {
3035
EXPECT_EQ(orStreamCreateWithPriority(&stream, 0, max_p + 1), orErrorUnknown);
3136
}
3237

38+
TEST_F(StreamTest, StreamCreateWithPriorityValidBounds) {
39+
orStream_t stream = nullptr;
40+
int min_p, max_p;
41+
orDeviceGetStreamPriorityRange(&min_p, &max_p);
42+
43+
// Lowest priority should be accepted.
44+
EXPECT_EQ(orStreamCreateWithPriority(&stream, 0, min_p), orSuccess);
45+
EXPECT_EQ(orStreamDestroy(stream), orSuccess);
46+
47+
// Highest priority should also be accepted.
48+
EXPECT_EQ(orStreamCreateWithPriority(&stream, 0, max_p), orSuccess);
49+
EXPECT_EQ(orStreamDestroy(stream), orSuccess);
50+
}
51+
52+
TEST_F(StreamTest, StreamDestroyNullptr) {
53+
// Destroying nullptr should follow CUDA error behavior.
54+
EXPECT_EQ(orStreamDestroy(nullptr), orErrorUnknown);
55+
}
56+
57+
TEST_F(StreamTest, StreamGetPriority) {
58+
orStream_t stream = nullptr;
59+
EXPECT_EQ(orStreamCreate(&stream), orSuccess);
60+
61+
int priority = -1;
62+
EXPECT_EQ(orStreamGetPriority(stream, &priority), orSuccess);
63+
EXPECT_EQ(priority, 0);
64+
65+
EXPECT_EQ(orStreamDestroy(stream), orSuccess);
66+
}
67+
3368
TEST_F(StreamTest, StreamTaskExecution) {
3469
orStream_t stream = nullptr;
3570
EXPECT_EQ(orStreamCreate(&stream), orSuccess);
@@ -43,6 +78,11 @@ TEST_F(StreamTest, StreamTaskExecution) {
4378
EXPECT_EQ(orStreamDestroy(stream), orSuccess);
4479
}
4580

81+
TEST_F(StreamTest, AddTaskToStreamNullptr) {
82+
// Queueing work should fail fast if the stream handle is invalid.
83+
EXPECT_EQ(openreg::addTaskToStream(nullptr, [] {}), orErrorUnknown);
84+
}
85+
4686
TEST_F(StreamTest, StreamQuery) {
4787
orStream_t stream = nullptr;
4888
EXPECT_EQ(orStreamCreate(&stream), orSuccess);
@@ -76,4 +116,18 @@ TEST_F(StreamTest, DeviceSynchronize) {
76116
EXPECT_EQ(orStreamDestroy(stream2), orSuccess);
77117
}
78118

119+
TEST_F(StreamTest, DeviceSynchronizeWithNoStreams) {
120+
// Even without registered streams, device sync should succeed.
121+
EXPECT_EQ(orDeviceSynchronize(), orSuccess);
122+
}
123+
124+
TEST_F(StreamTest, StreamPriorityRange) {
125+
int min_p = -1;
126+
int max_p = -1;
127+
// OpenReg currently exposes only one priority level; verify the fixed range.
128+
EXPECT_EQ(orDeviceGetStreamPriorityRange(&min_p, &max_p), orSuccess);
129+
EXPECT_EQ(min_p, 0);
130+
EXPECT_EQ(max_p, 0);
131+
}
132+
79133
} // namespace

0 commit comments

Comments
 (0)