Skip to content

Commit 0457b63

Browse files
authored
SWDEV-527781 - Remove Stream Validation in HIP APIs
1 parent f7482ef commit 0457b63

File tree

7 files changed

+21
-68
lines changed

7 files changed

+21
-68
lines changed

hipamd/src/hip_event.cpp

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -390,9 +390,6 @@ hipError_t hipEventRecord_common(hipEvent_t event, hipStream_t stream, unsigned
390390
return hipErrorInvalidHandle;
391391
}
392392
getStreamPerThread(stream);
393-
if (!hip::isValid(stream)) {
394-
return hipErrorContextIsDestroyed;
395-
}
396393
hip::Event* e = reinterpret_cast<hip::Event*>(event);
397394
hip::Stream* s = reinterpret_cast<hip::Stream*>(stream);
398395
hip::Stream* hip_stream = hip::getStream(stream);

hipamd/src/hip_graph.cpp

Lines changed: 5 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -42,9 +42,7 @@ inline hipError_t ihipGraphUpload(hipGraphExec_t graphExec, hipStream_t stream)
4242
if (graphExec == nullptr) {
4343
return hipErrorInvalidValue;
4444
}
45-
if (!hip::isValid(stream)) {
46-
return hipErrorContextIsDestroyed;
47-
}
45+
getStreamPerThread(stream);
4846
if (!hip::GraphExec::isGraphExecValid(reinterpret_cast<hip::GraphExec*>(graphExec))) {
4947
return hipErrorInvalidValue;
5048
}
@@ -1026,9 +1024,7 @@ hipError_t hipStreamIsCapturing_common(hipStream_t stream, hipStreamCaptureStatu
10261024
if (pCaptureStatus == nullptr) {
10271025
return hipErrorInvalidValue;
10281026
}
1029-
if (!hip::isValid(stream)) {
1030-
return hipErrorContextIsDestroyed;
1031-
}
1027+
getStreamPerThread(stream);
10321028
if (hip::Stream::StreamCaptureBlocking() == true &&
10331029
(stream == nullptr || stream == hipStreamLegacy)) {
10341030
return hipErrorStreamCaptureImplicit;
@@ -1069,9 +1065,7 @@ hipError_t hipThreadExchangeStreamCaptureMode(hipStreamCaptureMode* mode) {
10691065

10701066
hipError_t hipStreamBeginCapture_common(hipStream_t stream, hipStreamCaptureMode mode,
10711067
hipGraph_t graph = nullptr) {
1072-
if (!hip::isValid(stream)) {
1073-
return hipErrorContextIsDestroyed;
1074-
}
1068+
getStreamPerThread(stream);
10751069
// capture cannot be initiated on legacy stream
10761070
if (stream == nullptr || stream == hipStreamLegacy) {
10771071
return hipErrorStreamCaptureUnsupported;
@@ -1591,9 +1585,7 @@ hipError_t hipGraphExecDestroy(hipGraphExec_t pGraphExec) {
15911585
}
15921586

15931587
hipError_t ihipGraphLaunch(hip::GraphExec* graphExec, hipStream_t stream) {
1594-
if (!hip::isValid(stream)) {
1595-
return hipErrorContextIsDestroyed;
1596-
}
1588+
getStreamPerThread(stream);
15971589
hip::Stream* launch_stream = hip::getStream(stream);
15981590
return graphExec->Run(launch_stream);
15991591
}
@@ -1605,9 +1597,6 @@ hipError_t hipGraphLaunch_common(hip::GraphExec* graphExec, hipStream_t stream)
16051597
if (graphExec->GetNodeCount() == 0) {
16061598
return hipSuccess;
16071599
}
1608-
if (!hip::isValid(stream)) {
1609-
return hipErrorContextIsDestroyed;
1610-
}
16111600
return ihipGraphLaunch(graphExec, stream);
16121601
}
16131602

@@ -2004,9 +1993,7 @@ hipError_t hipStreamGetCaptureInfo_common(hipStream_t stream,
20041993
if (pCaptureStatus == nullptr) {
20051994
return hipErrorInvalidValue;
20061995
}
2007-
if (!hip::isValid(stream)) {
2008-
return hipErrorContextIsDestroyed;
2009-
}
1996+
getStreamPerThread(stream);
20101997
if (hip::Stream::StreamCaptureBlocking() == true &&
20111998
(stream == nullptr || stream == hipStreamLegacy)) {
20121999
return hipErrorStreamCaptureImplicit;

hipamd/src/hip_hmm.cpp

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -82,9 +82,7 @@ hipError_t hipMemPrefetchAsync(const void* dev_ptr, size_t count, int device,
8282
HIP_RETURN(hipErrorInvalidValue);
8383
}
8484

85-
if (!hip::isValid(stream)) {
86-
HIP_RETURN(hipErrorContextIsDestroyed);
87-
}
85+
getStreamPerThread(stream);
8886

8987
size_t offset = 0;
9088
amd::Memory* memObj = getMemoryObject(dev_ptr, offset);
@@ -238,9 +236,7 @@ hipError_t hipStreamAttachMemAsync(hipStream_t stream, void* dev_ptr,
238236
HIP_RETURN(hipErrorInvalidValue);
239237
}
240238

241-
if (!hip::isValid(stream)) {
242-
HIP_RETURN(hipErrorContextIsDestroyed);
243-
}
239+
getStreamPerThread(stream);
244240

245241
if (flags != hipMemAttachGlobal && flags != hipMemAttachHost && flags != hipMemAttachSingle) {
246242
HIP_RETURN(hipErrorInvalidValue);

hipamd/src/hip_memory.cpp

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1430,13 +1430,11 @@ hipError_t hipMemcpyAsync_common(void* dst, const void* src, size_t sizeBytes,
14301430
if (static_cast<uint32_t>(kind) > hipMemcpyDefault && kind != hipMemcpyDeviceToDeviceNoCU) {
14311431
return hipErrorInvalidMemcpyDirection;
14321432
}
1433+
getStreamPerThread(stream);
14331434
hip::Stream* hip_stream = hip::getStream(stream);
14341435
if (hip_stream == nullptr) {
14351436
return hipErrorInvalidValue;
14361437
}
1437-
if (!hip::isValid(stream)) {
1438-
return hipErrorContextIsDestroyed;
1439-
}
14401438
return ihipMemcpy(dst, src, sizeBytes, kind, *hip_stream, true);
14411439
}
14421440

@@ -2360,9 +2358,7 @@ hipError_t ihipMemcpyParam3D(const HIP_MEMCPY3D* pCopy, hipStream_t stream, bool
23602358
if (pCopy == nullptr) {
23612359
return hipErrorInvalidValue;
23622360
}
2363-
if (!hip::isValid(stream)) {
2364-
return hipErrorContextIsDestroyed;
2365-
}
2361+
getStreamPerThread(stream);
23662362
hipMemoryType srcMemoryType;
23672363
hipMemoryType dstMemoryType;
23682364
ihipCopyMemParamSet(pCopy, srcMemoryType, dstMemoryType);
@@ -2448,9 +2444,7 @@ hipError_t hipMemcpy2DValidateParams(hipMemcpyKind kind, hipStream_t stream = nu
24482444
return hipErrorInvalidMemcpyDirection;
24492445
}
24502446

2451-
if (!hip::isValid(stream)) {
2452-
return hipErrorInvalidValue;
2453-
}
2447+
getStreamPerThread(stream);
24542448

24552449
return hipSuccess;
24562450
}

hipamd/src/hip_mempool.cpp

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -85,9 +85,7 @@ hipError_t hipMallocAsync(void** dev_ptr, size_t size, hipStream_t stream) {
8585
if (dev_ptr == nullptr) {
8686
HIP_RETURN(hipErrorInvalidValue);
8787
}
88-
if (!hip::isValid(stream)) {
89-
HIP_RETURN(hipErrorInvalidHandle);
90-
}
88+
getStreamPerThread(stream);
9189
if (size == 0) {
9290
*dev_ptr = nullptr;
9391
HIP_RETURN(hipSuccess);
@@ -147,9 +145,7 @@ class FreeAsyncCommand : public amd::Command {
147145
hipError_t hipFreeAsync(void* dev_ptr, hipStream_t stream) {
148146
HIP_INIT_API(hipFreeAsync, dev_ptr, stream);
149147

150-
if (!hip::isValid(stream)) {
151-
HIP_RETURN(hipErrorInvalidHandle);
152-
}
148+
getStreamPerThread(stream);
153149

154150
hip::Stream* s = reinterpret_cast<hip::Stream*>(stream);
155151
auto hip_stream = (stream == nullptr || stream == hipStreamLegacy) ?
@@ -376,9 +372,7 @@ hipError_t hipMallocFromPoolAsync(
376372
if ((dev_ptr == nullptr) || (mem_pool == nullptr)) {
377373
HIP_RETURN(hipErrorInvalidValue);
378374
}
379-
if (!hip::isValid(stream)) {
380-
HIP_RETURN(hipErrorInvalidHandle);
381-
}
375+
getStreamPerThread(stream);
382376
if (size == 0) {
383377
*dev_ptr = nullptr;
384378
HIP_RETURN(hipSuccess);

hipamd/src/hip_peer.cpp

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -209,9 +209,7 @@ hipError_t hipMemcpyPeerAsync(void* dst, int dstDevice, const void* src, int src
209209
srcDevice < 0 || dstDevice < 0) {
210210
HIP_RETURN(hipErrorInvalidDevice);
211211
}
212-
if (!hip::isValid(stream)) {
213-
return hipErrorContextIsDestroyed;
214-
}
212+
getStreamPerThread(stream);
215213
hip::Stream* hip_stream = hip::getStream(stream);
216214
if (hip_stream == nullptr) {
217215
return hipErrorInvalidValue;

hipamd/src/hip_stream.cpp

Lines changed: 7 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -323,9 +323,7 @@ hipError_t hipDeviceGetStreamPriorityRange(int* leastPriority, int* greatestPrio
323323
// ================================================================================================
324324
hipError_t hipStreamGetFlags_common(hipStream_t stream, unsigned int* flags) {
325325
if ((flags != nullptr) && (stream != nullptr)) {
326-
if (!hip::isValid(stream)) {
327-
return hipErrorContextIsDestroyed;
328-
}
326+
getStreamPerThread(stream);
329327
*flags = reinterpret_cast<hip::Stream*>(stream)->Flags();
330328
} else {
331329
return hipErrorInvalidValue;
@@ -349,9 +347,7 @@ hipError_t hipStreamGetFlags_spt(hipStream_t stream, unsigned int* flags) {
349347

350348
// ================================================================================================
351349
hipError_t hipStreamSynchronize_common(hipStream_t stream) {
352-
if (!hip::isValid(stream)) {
353-
HIP_RETURN(hipErrorContextIsDestroyed);
354-
}
350+
getStreamPerThread(stream);
355351
if (stream != nullptr && stream != hipStreamLegacy) {
356352
// If still capturing return error
357353
if (hip::Stream::StreamCaptureOngoing(stream) == true) {
@@ -398,9 +394,6 @@ hipError_t hipStreamDestroy(hipStream_t stream) {
398394
if (stream == hipStreamPerThread || stream == hipStreamLegacy) {
399395
HIP_RETURN(hipErrorInvalidResourceHandle);
400396
}
401-
if (!hip::isValid(stream)) {
402-
HIP_RETURN(hipErrorContextIsDestroyed);
403-
}
404397
hip::Stream* s = reinterpret_cast<hip::Stream*>(stream);
405398
if (s->GetCaptureStatus() != hipStreamCaptureStatusNone) {
406399
if (s->GetParentStream() != nullptr) {
@@ -448,9 +441,10 @@ void WaitThenDecrementSignal(hipStream_t stream, hipError_t status, void* user_d
448441
// ================================================================================================
449442
hipError_t hipStreamWaitEvent_common(hipStream_t stream, hipEvent_t event, unsigned int flags) {
450443
hipError_t status = hipSuccess;
451-
if (event == nullptr || !hip::isValid(stream)) {
444+
if (event == nullptr) {
452445
return hipErrorInvalidHandle;
453446
}
447+
getStreamPerThread(stream);
454448
hip::Stream* waitStream = hip::getStream(stream);
455449
hip::Event* e = reinterpret_cast<hip::Event*>(event);
456450
auto eventStreamHandle = reinterpret_cast<hipStream_t>(e->GetCaptureStream());
@@ -511,9 +505,7 @@ hipError_t hipStreamWaitEvent_spt(hipStream_t stream, hipEvent_t event, unsigned
511505

512506
// ================================================================================================
513507
hipError_t hipStreamQuery_common(hipStream_t stream) {
514-
if (!hip::isValid(stream)) {
515-
return hipErrorContextIsDestroyed;
516-
}
508+
getStreamPerThread(stream);
517509
if (stream != nullptr) {
518510
// If still capturing return error
519511
if (hip::Stream::StreamCaptureOngoing(stream) == true) {
@@ -566,10 +558,7 @@ hipError_t hipStreamQuery_spt(hipStream_t stream) {
566558
}
567559

568560
hipError_t streamCallback_common(hipStream_t stream, StreamCallback* cbo, void* userData) {
569-
if (!hip::isValid(stream)) {
570-
return hipErrorContextIsDestroyed;
571-
}
572-
561+
getStreamPerThread(stream);
573562
hip::Stream* hip_stream = hip::getStream(stream);
574563
amd::Command* last_command = hip_stream->getLastQueuedCommand(true);
575564
amd::Command::EventWaitList eventWaitList;
@@ -688,9 +677,7 @@ hipError_t hipStreamGetPriority_common(hipStream_t stream, int* priority) {
688677
}
689678

690679
if ((priority != nullptr) && (stream != nullptr)) {
691-
if (!hip::isValid(stream)) {
692-
return hipErrorContextIsDestroyed;
693-
}
680+
getStreamPerThread(stream);
694681
*priority = static_cast<int>(reinterpret_cast<hip::Stream*>(stream)->GetPriority());
695682
} else {
696683
return hipErrorInvalidValue;

0 commit comments

Comments
 (0)