1010#pragma once
1111
1212#include " common.hpp"
13+ #include " context.hpp"
14+ #include " device.hpp"
1315
1416struct ur_program_handle_t_ : _ur_object {
1517 // ur_program_handle_t_() {}
@@ -68,50 +70,60 @@ struct ur_program_handle_t_ : _ur_object {
6870 // Construct a program in IL.
6971 ur_program_handle_t_ (state St, ur_context_handle_t Context, const void *Input,
7072 size_t Length)
71- : Context{Context}, NativeDevice{nullptr }, NativeProperties{nullptr },
72- OwnZeModule{true }, State{St}, Code{new uint8_t [Length]},
73- CodeLength{Length}, ZeModule{nullptr }, ZeBuildLog{nullptr } {
74- std::memcpy (Code.get (), Input, Length);
73+ : Context{Context}, NativeProperties{nullptr }, OwnZeModule{true },
74+ SpirvCode{new uint8_t [Length]}, SpirvCodeLength{Length},
75+ InteropZeModule{nullptr } {
76+ std::memcpy (SpirvCode.get (), Input, Length);
77+ // All devices have the program in IL state.
78+ for (auto &Device : Context->getDevices ()) {
79+ DeviceData &PerDevData = DeviceDataMap[Device->ZeDevice ];
80+ PerDevData.State = St;
81+ }
7582 }
7683
77- // Construct a program in NATIVE.
84+ // Construct a program in NATIVE for multiple devices .
7885 ur_program_handle_t_ (state St, ur_context_handle_t Context,
79- ur_device_handle_t Device,
86+ const uint32_t NumDevices,
87+ const ur_device_handle_t *Devices,
8088 const ur_program_properties_t *Properties,
81- const void *Input, size_t Length)
82- : Context{Context}, NativeDevice(Device), NativeProperties(Properties),
83- OwnZeModule{true }, State{St}, Code{new uint8_t [Length]},
84- CodeLength{Length}, ZeModule{nullptr }, ZeBuildLog{nullptr } {
85- std::memcpy (Code.get (), Input, Length);
89+ const uint8_t **Inputs, const size_t *Lengths)
90+ : Context{Context}, NativeProperties(Properties), OwnZeModule{true },
91+ InteropZeModule{nullptr } {
92+ for (uint32_t I = 0 ; I < NumDevices; ++I) {
93+ DeviceData &PerDevData = DeviceDataMap[Devices[I]->ZeDevice ];
94+ PerDevData.State = St;
95+ PerDevData.Binary = std::make_pair (
96+ std::unique_ptr<uint8_t []>(new uint8_t [Lengths[I]]), Lengths[I]);
97+ std::memcpy (PerDevData.Binary .first .get (), Inputs[I], Lengths[I]);
98+ }
8699 }
87100
88101 // Construct a program in Exe or Invalid state.
89102 ur_program_handle_t_ (state St, ur_context_handle_t Context,
90103 ze_module_handle_t ZeModule,
91104 ze_module_build_log_handle_t ZeBuildLog)
92- : Context{Context}, NativeDevice{nullptr }, NativeProperties{nullptr },
93- OwnZeModule{true }, State{St}, ZeModule{ZeModule}, ZeBuildLog{
94- ZeBuildLog} {}
105+ : Context{Context}, NativeProperties{nullptr }, OwnZeModule{true },
106+ InteropZeModule{ZeModule} {
107+ for (auto &Device : Context->getDevices ()) {
108+ DeviceData &PerDevData = DeviceDataMap[Device->ZeDevice ];
109+ PerDevData.State = St;
110+ }
111+ }
95112
96113 // Construct a program in Exe state (interop).
97114 ur_program_handle_t_ (state St, ur_context_handle_t Context,
98115 ze_module_handle_t ZeModule, bool OwnZeModule)
99- : Context{Context}, NativeDevice{nullptr }, NativeProperties{nullptr },
100- OwnZeModule{OwnZeModule}, State{St}, ZeModule{ZeModule}, ZeBuildLog{
101- nullptr } {}
102-
103- // Construct a program from native handle
104- ur_program_handle_t_ (state St, ur_context_handle_t Context,
105- ze_module_handle_t ZeModule)
106- : Context{Context}, NativeDevice{nullptr }, NativeProperties{nullptr },
107- OwnZeModule{true }, State{St}, ZeModule{ZeModule}, ZeBuildLog{nullptr } {}
116+ : Context{Context}, NativeProperties{nullptr }, OwnZeModule{OwnZeModule},
117+ InteropZeModule{ZeModule} {
118+ // TODO: Currently it is not possible to understand the device associated
119+ // with provided ZeModule. So we can't set the state on that device to Exe.
120+ }
108121
109122 // Construct a program in Invalid state with a custom error message.
110123 ur_program_handle_t_ (state St, ur_context_handle_t Context,
111124 const std::string &ErrorMessage)
112- : Context{Context}, NativeDevice{nullptr }, NativeProperties{nullptr },
113- OwnZeModule{true }, ErrorMessage{ErrorMessage}, State{St},
114- ZeModule{nullptr }, ZeBuildLog{nullptr } {}
125+ : Context{Context}, NativeProperties{nullptr }, OwnZeModule{true },
126+ ErrorMessage{ErrorMessage}, InteropZeModule{nullptr } {}
115127
116128 ~ur_program_handle_t_ ();
117129 void ur_release_program_resources (bool deletion);
@@ -122,9 +134,6 @@ struct ur_program_handle_t_ : _ur_object {
122134
123135 const ur_context_handle_t Context; // Context of the program.
124136
125- // Device Handle used for the Native Build
126- ur_device_handle_t NativeDevice;
127-
128137 // Properties used for the Native Build
129138 const ur_program_properties_t *NativeProperties;
130139
@@ -136,35 +145,46 @@ struct ur_program_handle_t_ : _ur_object {
136145 // message from a call to urProgramLink.
137146 const std::string ErrorMessage;
138147
139- state State;
140-
141148 // In IL and Object states, this contains the SPIR-V representation of the
142- // module. In Native state, it contains the native code.
143- std::unique_ptr<uint8_t []> Code ; // Array containing raw IL / native code.
144- size_t CodeLength{ 0 }; // Size (bytes) of the array.
149+ // module.
150+ std::unique_ptr<uint8_t []> SpirvCode ; // Array containing raw IL code.
151+ size_t SpirvCodeLength; // Size (bytes) of the array.
145152
146153 // Used only in IL and Object states. Contains the SPIR-V specialization
147154 // constants as a map from the SPIR-V "SpecID" to a buffer that contains the
148155 // associated value. The caller of the PI layer is responsible for
149156 // maintaining the storage of this buffer.
150157 std::unordered_map<uint32_t , const void *> SpecConstants;
151158
152- // Used only in Object state. Contains the build flags from the last call to
153- // urProgramCompile().
154- std::string BuildFlags;
155-
156- // The Level Zero module handle. Used primarily in Exe state.
157- ze_module_handle_t ZeModule{};
158-
159- // Map of L0 Modules created for all the devices for which a UR Program
160- // has been built.
161- std::unordered_map<ze_device_handle_t , ze_module_handle_t > ZeModuleMap;
162-
163- // The Level Zero build log from the last call to zeModuleCreate().
164- ze_module_build_log_handle_t ZeBuildLog{};
159+ // The Level Zero module handle for interoperability.
160+ // This module handle is either initialized with the handle provided to
161+ // interoperability UR API, or with one of the handles after building the
162+ // program. This handle is returned by UR API which allows to get the native
163+ // handle from the program.
164+ // TODO: Currently interoparability UR API does not support multiple devices.
165+ ze_module_handle_t InteropZeModule{};
166+
167+ struct DeviceData {
168+ // Log from the result of building the program for the device using
169+ // zeModuleCreate().
170+ ze_module_build_log_handle_t ZeBuildLog{};
171+
172+ // The Level Zero module handle for the device. Used primarily in Exe state.
173+ ze_module_handle_t ZeModule{};
174+
175+ // In Native state, contains the pair of the binary code for the device and
176+ // its length in bytes.
177+ std::pair<std::unique_ptr<uint8_t []>, size_t > Binary{nullptr , 0 };
178+
179+ // Build flags used for building the program for the device.
180+ // May be different for different devices, for example, if
181+ // urProgramCompileExp was called multiple times with different build flags
182+ // for different devices.
183+ std::string BuildFlags{};
184+
185+ // State of the program for the device.
186+ state State{};
187+ };
165188
166- // Map of L0 Module Build logs created for all the devices for which a UR
167- // Program has been built.
168- std::unordered_map<ze_device_handle_t , ze_module_build_log_handle_t >
169- ZeBuildLogMap;
189+ std::unordered_map<ze_device_handle_t , DeviceData> DeviceDataMap;
170190};
0 commit comments