[Offload] Full AMD support for olMemFill #154958
Changes from 1 commit
The plugin-side changes first extend the AMDGPU stream's host memcpy arguments with a repeat count:

```diff
@@ -924,6 +924,7 @@ struct AMDGPUStreamTy {
     void *Dst;
     const void *Src;
     size_t Size;
+    size_t NumTimes;
   };

   /// Utility struct holding arguments for freeing buffers to memory managers.
```
```diff
@@ -974,9 +975,14 @@ struct AMDGPUStreamTy {
     StreamSlotTy() : Signal(nullptr), Callbacks({}), ActionArgs({}) {}

     /// Schedule a host memory copy action on the slot.
-    Error schedHostMemoryCopy(void *Dst, const void *Src, size_t Size) {
+    ///
+    /// NumTimes will repeat the copy that many times, sequentially in the
+    /// destination buffer.
+    Error schedHostMemoryCopy(void *Dst, const void *Src, size_t Size,
+                              size_t NumTimes = 1) {
       Callbacks.emplace_back(memcpyAction);
-      ActionArgs.emplace_back().MemcpyArgs = MemcpyArgsTy{Dst, Src, Size};
+      ActionArgs.emplace_back().MemcpyArgs =
+          MemcpyArgsTy{Dst, Src, Size, NumTimes};
       return Plugin::success();
     }
```
```diff
@@ -1216,7 +1222,12 @@ struct AMDGPUStreamTy {
     assert(Args->Dst && "Invalid destination buffer");
     assert(Args->Src && "Invalid source buffer");

-    std::memcpy(Args->Dst, Args->Src, Args->Size);
+    auto BasePtr = Args->Dst;
+    for (size_t I = 0; I < Args->NumTimes; I++) {
+      std::memcpy(BasePtr, Args->Src, Args->Size);
+      BasePtr = reinterpret_cast<void *>(reinterpret_cast<uintptr_t>(BasePtr) +
+                                         Args->Size);
+    }

     return Plugin::success();
   }
```
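For reference, the repeat semantics introduced here amount to tiling the source pattern end-to-end into the destination buffer. Below is a minimal standalone sketch of the same loop as in `memcpyAction`; the helper name and test values are illustrative, not part of the patch:

```cpp
#include <cassert>
#include <cstdint>
#include <cstring>

// Hypothetical helper mirroring the memcpyAction loop above: copy Size bytes
// from Src into Dst, NumTimes times, at consecutive offsets.
static void replicatePattern(void *Dst, const void *Src, size_t Size,
                             size_t NumTimes) {
  auto *Out = static_cast<char *>(Dst);
  for (size_t I = 0; I < NumTimes; I++) {
    std::memcpy(Out, Src, Size);
    Out += Size; // Advance by one pattern width per iteration.
  }
}

int main() {
  const uint16_t Pattern = 0xABCD;
  uint16_t Buffer[8] = {};
  replicatePattern(Buffer, &Pattern, sizeof(Pattern), 8);
  for (uint16_t V : Buffer)
    assert(V == 0xABCD); // Every element now holds the 2-byte pattern.
  return 0;
}
```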
```diff
@@ -1421,7 +1432,8 @@ struct AMDGPUStreamTy {
   /// manager once the operation completes.
   Error pushMemoryCopyH2DAsync(void *Dst, const void *Src, void *Inter,
                                uint64_t CopySize,
-                               AMDGPUMemoryManagerTy &MemoryManager) {
+                               AMDGPUMemoryManagerTy &MemoryManager,
+                               size_t NumTimes = 1) {
     // Retrieve available signals for the operation's outputs.
     AMDGPUSignalTy *OutputSignals[2] = {};
     if (auto Err = SignalManager.getResources(/*Num=*/2, OutputSignals))
```
```diff
@@ -1443,7 +1455,8 @@ struct AMDGPUStreamTy {
     // The std::memcpy is done asynchronously using an async handler. We store
     // the function's information in the action but it is not actually a
     // post action.
-    if (auto Err = Slots[Curr].schedHostMemoryCopy(Inter, Src, CopySize))
+    if (auto Err =
+            Slots[Curr].schedHostMemoryCopy(Inter, Src, CopySize, NumTimes))
       return Err;

     // Make changes on this slot visible to the async handler's thread.
```
```diff
@@ -1464,7 +1477,12 @@ struct AMDGPUStreamTy {
       std::tie(Curr, InputSignal) = consume(OutputSignal);
     } else {
       // All preceding operations completed, copy the memory synchronously.
-      std::memcpy(Inter, Src, CopySize);
+      auto *InterPtr = Inter;
+      for (size_t I = 0; I < NumTimes; I++) {
+        std::memcpy(InterPtr, Src, CopySize);
+        InterPtr = reinterpret_cast<void *>(
+            reinterpret_cast<uintptr_t>(InterPtr) + CopySize);
+      }

       // Return the second signal because it will not be used.
       OutputSignals[1]->decreaseUseCount();
```
```diff
@@ -1481,11 +1499,11 @@ struct AMDGPUStreamTy {
     if (InputSignal && InputSignal->load()) {
       hsa_signal_t InputSignalRaw = InputSignal->get();
       return hsa_utils::asyncMemCopy(UseMultipleSdmaEngines, Dst, Agent, Inter,
-                                     Agent, CopySize, 1, &InputSignalRaw,
-                                     OutputSignal->get());
+                                     Agent, CopySize * NumTimes, 1,
+                                     &InputSignalRaw, OutputSignal->get());
     }
     return hsa_utils::asyncMemCopy(UseMultipleSdmaEngines, Dst, Agent, Inter,
-                                   Agent, CopySize, 0, nullptr,
+                                   Agent, CopySize * NumTimes, 0, nullptr,
                                    OutputSignal->get());
   }
```
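Since the intermediate pinned buffer now holds the pattern NumTimes over, the single device copy must move CopySize * NumTimes bytes. A small sketch of that size arithmetic, assuming the fill size tiles exactly, as used by the slow fill path later in the diff; the struct and helper here are illustrative, not from the patch:

```cpp
#include <cassert>
#include <cstddef>

// Illustrative only: how a fill request maps onto the staged H2D copy.
struct StagedFillPlan {
  size_t CopySize; // Bytes copied per host-side iteration (pattern width).
  size_t NumTimes; // Host-side repetitions into the pinned buffer.
  size_t DmaBytes; // Bytes moved by the single device copy.
};

static StagedFillPlan planStagedFill(size_t TotalSize, size_t PatternSize) {
  assert(TotalSize % PatternSize == 0 && "fill size must tile exactly");
  return {PatternSize, TotalSize / PatternSize, TotalSize};
}

int main() {
  // E.g. filling 24 bytes with a 6-byte pattern: 4 host-side copies, then
  // one 24-byte device transfer.
  StagedFillPlan Plan = planStagedFill(/*TotalSize=*/24, /*PatternSize=*/6);
  assert(Plan.NumTimes == 4);
  assert(Plan.CopySize * Plan.NumTimes == Plan.DmaBytes);
  return 0;
}
```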
```diff
@@ -2611,26 +2629,73 @@ struct AMDGPUDeviceTy : public GenericDeviceTy, AMDGenericDeviceTy {
   Error dataFillImpl(void *TgtPtr, const void *PatternPtr, int64_t PatternSize,
                      int64_t Size,
                      AsyncInfoWrapperTy &AsyncInfoWrapper) override {
-    hsa_status_t Status;
-
-    // We can use hsa_amd_memory_fill for this size, but it's not async so the
-    // queue needs to be synchronized first
-    if (PatternSize == 4) {
-      if (AsyncInfoWrapper.hasQueue())
-        if (auto Err = synchronize(AsyncInfoWrapper))
-          return Err;
-
-      Status = hsa_amd_memory_fill(TgtPtr,
-                                   *static_cast<const uint32_t *>(PatternPtr),
-                                   Size / PatternSize);
-
-      if (auto Err =
-              Plugin::check(Status, "error in hsa_amd_memory_fill: %s\n"))
-        return Err;
-    } else {
-      // TODO: Implement for AMDGPU. Most likely by doing the fill in pinned
-      // memory and copying to the device in one go.
-      return Plugin::error(ErrorCode::UNSUPPORTED, "Unsupported fill size");
-    }
+    // Fast case, where we can use the 4 byte hsa_amd_memory_fill
+    if (Size % 4 == 0 &&
+        (PatternSize == 4 || PatternSize == 2 || PatternSize == 1)) {
+      uint32_t Pattern;
+      if (PatternSize == 1) {
+        auto *Byte = reinterpret_cast<const uint8_t *>(PatternPtr);
+        Pattern = *Byte | *Byte << 8 | *Byte << 16 | *Byte << 24;
+      } else if (PatternSize == 2) {
+        auto *Word = reinterpret_cast<const uint16_t *>(PatternPtr);
+        Pattern = *Word | (*Word << 16);
+      } else if (PatternSize == 4) {
+        Pattern = *reinterpret_cast<const uint32_t *>(PatternPtr);
+      } else {
+        // Shouldn't be here if the pattern size is outwith those values
+        std::terminate();
+      }
+
+      if (hasPendingWorkImpl(AsyncInfoWrapper)) {
+        AMDGPUStreamTy *Stream = nullptr;
+        if (auto Err = getStream(AsyncInfoWrapper, Stream))
+          return Err;
+
+        struct MemFillArgsTy {
+          void *Dst;
+          uint32_t Pattern;
+          int64_t Size;
+        };
+        auto *Args = new MemFillArgsTy{TgtPtr, Pattern, Size / 4};
+        auto Fill = [](void *Data) {
+          MemFillArgsTy *Args = reinterpret_cast<MemFillArgsTy *>(Data);
+          assert(Args && "Invalid arguments");
+
+          auto Status =
+              hsa_amd_memory_fill(Args->Dst, Args->Pattern, Args->Size);
+          delete Args;
+          auto Err =
+              Plugin::check(Status, "error in hsa_amd_memory_fill: %s\n");
+          if (Err) {
+            FATAL_MESSAGE(1, "error performing async fill: %s",
+                          toString(std::move(Err)).data());
+          }
+        };
+
+        // hsa_amd_memory_fill doesn't signal completion using a signal, so use
+        // the existing host callback logic to handle that instead
+        return Stream->pushHostCallback(Fill, Args);
+      } else {
+        // If there is no pending work, do the fill synchronously
+        auto Status = hsa_amd_memory_fill(TgtPtr, Pattern, Size / 4);
+        return Plugin::check(Status, "error in hsa_amd_memory_fill: %s\n");
+      }
+    }
+
+    // Slow case; allocate an appropriate memory size and enqueue copies
+    void *PinnedPtr = nullptr;
+    AMDGPUMemoryManagerTy &PinnedMemoryManager =
+        HostDevice.getPinnedMemoryManager();
+    if (auto Err = PinnedMemoryManager.allocate(Size, &PinnedPtr))
+      return Err;
+
+    AMDGPUStreamTy *Stream = nullptr;
+    if (auto Err = getStream(AsyncInfoWrapper, Stream))
+      return Err;
+
+    return Stream->pushMemoryCopyH2DAsync(TgtPtr, PatternPtr, PinnedPtr,
+                                          PatternSize, PinnedMemoryManager,
+                                          Size / PatternSize);
   }

   /// Initialize the async info for interoperability purposes.
```

Comment on lines +2667 to +2670 (the FATAL_MESSAGE fallback):

Reviewer: We should try as hard as possible not to just roll over and die inside of the plugin. We don't do a great job of it so far.

Author: I agree, but that would require liboffload/PluginInterface to have some kind of async error handling. Maybe […]. Anyway, I think that should be a separate task. What I'm doing here is the same as […].

Reviewer: That's probably something we should try to figure out at some point. I forget if we discussed it before, but we definitely don't want anything like […].

Author: If we have discussed it before, that memory is lost to me. What are your particular grievances with […]?

Reviewer: The same reason that […]. I think HSA uses callbacks in a similar way, but in my head it would probably be best if we just made it an event or something on the stream the user can query if needed.

Author: I like the idea of linking errors to streams (e.g. the stream is put into an error state and […]). I think this is something that warrants its own thoughts and discussion. Can I merge this as is (assuming no other issues) just to unblock AMD, and look into a design for error handling as a separate change?
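The fast path widens 1- and 2-byte patterns into a 32-bit word so a single hsa_amd_memory_fill call can cover them. A standalone sketch of that widening; the helper is hypothetical, not part of the patch:

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>

// Replicate a 1- or 2-byte pattern across a 32-bit word, as the fast path
// above does before calling hsa_amd_memory_fill.
static uint32_t widenPattern(uint32_t Pattern, size_t PatternSize) {
  switch (PatternSize) {
  case 1:
    return Pattern | Pattern << 8 | Pattern << 16 | Pattern << 24;
  case 2:
    return Pattern | Pattern << 16;
  default:
    return Pattern; // Already 4 bytes wide.
  }
}

int main() {
  assert(widenPattern(0xAB, 1) == 0xABABABABu);
  assert(widenPattern(0xABCD, 2) == 0xABCDABCDu);
  assert(widenPattern(0x01020304, 4) == 0x01020304u);
  return 0;
}
```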
The second file in the diff extends the offload unit-test fixtures with a helper for holding a queue open:

```diff
@@ -89,6 +89,40 @@ template <typename Fn> inline void threadify(Fn body) {
   }
 }

+/// Enqueues a task to the queue that can be manually resolved.
+/// It will block until `trigger` is called.
+struct ManuallyTriggeredTask {
+  std::mutex M;
+  std::condition_variable CV;
+  bool Flag = false;
+  ol_event_handle_t CompleteEvent;
+
+  ol_result_t enqueue(ol_queue_handle_t Queue) {
+    if (auto Err = olLaunchHostFunction(
+            Queue,
+            [](void *That) {
+              static_cast<ManuallyTriggeredTask *>(That)->wait();
+            },
+            this))
+      return Err;
+
+    return olCreateEvent(Queue, &CompleteEvent);
+  }
+
+  void wait() {
+    std::unique_lock<std::mutex> lk(M);
+    CV.wait_for(lk, std::chrono::milliseconds(1000), [&] { return Flag; });
+    EXPECT_TRUE(Flag);
+  }
+
+  ol_result_t trigger() {
+    Flag = true;
+    CV.notify_one();
+
+    return olSyncEvent(CompleteEvent);
+  }
+};
+
 struct OffloadTest : ::testing::Test {
   ol_device_handle_t Host = TestEnvironment::getHostDevice();
 };
```

Comment on ManuallyTriggeredTask:

Reviewer: Not entirely sure why we need this, but I'll leave it be.

Author: The AMD driver has two different code paths depending on whether the queue is empty or not. This struct allows the test framework to "hold" a task in the queue, making it non-empty, so that the non-empty path can be tested.
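To make the helper's purpose concrete, here is a rough sketch of how a test might use it to force the pending-work path of olMemFill. The fixture (OffloadQueueTest and its Queue member) and the device buffer Ptr/Size are assumptions for illustration, not code from this PR:

```cpp
// Hypothetical test: hold the queue open so it reports pending work, issue
// the fill, then release the held task and drain the queue.
TEST_F(OffloadQueueTest, MemFillWhileQueueBusy) {
  ManuallyTriggeredTask Blocker;
  ASSERT_SUCCESS(Blocker.enqueue(Queue)); // The queue now has pending work.

  // With work pending, the AMD plugin should defer the fill to a host
  // callback rather than filling synchronously.
  uint32_t Pattern = 0xDEADBEEF;
  ASSERT_SUCCESS(olMemFill(Queue, Ptr, sizeof(Pattern), &Pattern, Size));

  ASSERT_SUCCESS(Blocker.trigger()); // Release the held task...
  ASSERT_SUCCESS(olSyncQueue(Queue)); // ...and wait for the fill to finish.
}
```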