33// See LICENSE.TXT
44// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
55
6+ #include < algorithm>
67#include < cstring>
78#include < fstream>
89
@@ -41,6 +42,23 @@ std::ostream &operator<<(std::ostream &out,
4142 return out;
4243}
4344
45+ std::ostream &operator <<(std::ostream &out, const ur_device_handle_t &device) {
46+ size_t size;
47+ urDeviceGetInfo (device, UR_DEVICE_INFO_NAME, 0 , nullptr , &size);
48+ std::vector<char > name (size);
49+ urDeviceGetInfo (device, UR_DEVICE_INFO_NAME, size, name.data (), nullptr );
50+ out << name.data ();
51+ return out;
52+ }
53+
54+ std::ostream &operator <<(std::ostream &out,
55+ const std::vector<ur_device_handle_t > &devices) {
56+ for (auto device : devices) {
57+ out << " \n * \" " << device << " \" " ;
58+ }
59+ return out;
60+ }
61+
4462uur::PlatformEnvironment::PlatformEnvironment (int argc, char **argv)
4563 : platform_options{parsePlatformOptions (argc, argv)} {
4664 instance = this ;
@@ -100,14 +118,16 @@ uur::PlatformEnvironment::PlatformEnvironment(int argc, char **argv)
100118 }
101119
102120 if (platform_options.platform_name .empty ()) {
103- if (platforms.size () == 1 ) {
121+
122+ if (platforms.size () == 1 || platform_options.platforms_count == 1 ) {
104123 platform = platforms[0 ];
105124 } else {
106125 std::stringstream ss_error;
107126 ss_error << " Select a single platform from below using the "
108127 " --platform=NAME "
109128 " command-line option:"
110- << platforms;
129+ << platforms << std::endl
130+ << " or set --platforms_count=1." ;
111131 error = ss_error.str ();
112132 return ;
113133 }
@@ -136,7 +156,8 @@ uur::PlatformEnvironment::PlatformEnvironment(int argc, char **argv)
136156 << " \" not found. Select a single platform from below "
137157 " using the "
138158 " --platform=NAME command-line options:"
139- << platforms;
159+ << platforms << std::endl
160+ << " or set --platforms_count=1." ;
140161 error = ss_error.str ();
141162 return ;
142163 }
@@ -177,6 +198,10 @@ PlatformEnvironment::parsePlatformOptions(int argc, char **argv) {
177198 arg, " --platform=" , sizeof (" --platform=" ) - 1 ) == 0 ) {
178199 options.platform_name =
179200 std::string (&arg[std::strlen (" --platform=" )]);
201+ } else if (std::strncmp (arg, " --platforms_count=" ,
202+ sizeof (" --platforms_count=" ) - 1 ) == 0 ) {
203+ options.platforms_count = std::strtoul (
204+ &arg[std::strlen (" --platforms_count=" )], nullptr , 10 );
180205 }
181206 }
182207
@@ -192,10 +217,31 @@ PlatformEnvironment::parsePlatformOptions(int argc, char **argv) {
192217 return options;
193218}
194219
220+ DevicesEnvironment::DeviceOptions
221+ DevicesEnvironment::parseDeviceOptions (int argc, char **argv) {
222+ DeviceOptions options;
223+ for (int argi = 1 ; argi < argc; ++argi) {
224+ const char *arg = argv[argi];
225+ if (!(std::strcmp (arg, " -h" ) && std::strcmp (arg, " --help" ))) {
226+ // TODO - print help
227+ break ;
228+ } else if (std::strncmp (arg, " --device=" , sizeof (" --device=" ) - 1 ) ==
229+ 0 ) {
230+ options.device_name = std::string (&arg[std::strlen (" --device=" )]);
231+ } else if (std::strncmp (arg, " --devices_count=" ,
232+ sizeof (" --devices_count=" ) - 1 ) == 0 ) {
233+ options.devices_count = std::strtoul (
234+ &arg[std::strlen (" --devices_count=" )], nullptr , 10 );
235+ }
236+ }
237+ return options;
238+ }
239+
195240DevicesEnvironment *DevicesEnvironment::instance = nullptr ;
196241
197242DevicesEnvironment::DevicesEnvironment (int argc, char **argv)
198- : PlatformEnvironment(argc, argv) {
243+ : PlatformEnvironment(argc, argv),
244+ device_options (parseDeviceOptions(argc, argv)) {
199245 instance = this ;
200246 if (!error.empty ()) {
201247 return ;
@@ -209,11 +255,64 @@ DevicesEnvironment::DevicesEnvironment(int argc, char **argv)
209255 error = " Could not find any devices associated with the platform" ;
210256 return ;
211257 }
212- devices.resize (count);
213- if (urDeviceGet (platform, UR_DEVICE_TYPE_ALL, count, devices.data (),
214- nullptr )) {
215- error = " urDeviceGet() failed to get devices." ;
216- return ;
258+
259+ // Get the argument (devices_count) to limit test devices count.
260+ // In case, the devices_count is "0", the variable count will not be changed.
261+ // The CTS will run on all devices.
262+ if (device_options.device_name .empty ()) {
263+ if (device_options.devices_count >
264+ (std::numeric_limits<uint32_t >::max)()) {
265+ error = " Invalid devices_count argument" ;
266+ return ;
267+ } else if (device_options.devices_count > 0 ) {
268+ count = (std::min)(
269+ count, static_cast <uint32_t >(device_options.devices_count ));
270+ }
271+ devices.resize (count);
272+ if (urDeviceGet (platform, UR_DEVICE_TYPE_ALL, count, devices.data (),
273+ nullptr )) {
274+ error = " urDeviceGet() failed to get devices." ;
275+ return ;
276+ }
277+ } else {
278+ devices.resize (count);
279+ if (urDeviceGet (platform, UR_DEVICE_TYPE_ALL, count, devices.data (),
280+ nullptr )) {
281+ error = " urDeviceGet() failed to get devices." ;
282+ return ;
283+ }
284+ for (u_long i = 0 ; i < count; i++) {
285+ size_t size;
286+ if (urDeviceGetInfo (devices[i], UR_DEVICE_INFO_NAME, 0 , nullptr ,
287+ &size)) {
288+ error = " urDeviceGetInfo() failed" ;
289+ return ;
290+ }
291+ std::vector<char > device_name (size);
292+ if (urDeviceGetInfo (devices[i], UR_DEVICE_INFO_NAME, size,
293+ device_name.data (), nullptr )) {
294+ error = " urDeviceGetInfo() failed" ;
295+ return ;
296+ }
297+ if (device_options.device_name == device_name.data ()) {
298+ device = devices[i];
299+ devices.clear ();
300+ devices.resize (1 );
301+ devices[0 ] = device;
302+ break ;
303+ }
304+ }
305+ if (!device) {
306+ std::stringstream ss_error;
307+ ss_error << " Device \" " << device_options.device_name
308+ << " \" not found. Select a single device from below "
309+ " using the "
310+ " --device=NAME command-line options:"
311+ << devices << std::endl
312+ << " or set --devices_count=COUNT." ;
313+ error = ss_error.str ();
314+ return ;
315+ }
217316 }
218317}
219318
0 commit comments