@@ -42,6 +42,23 @@ std::ostream &operator<<(std::ostream &out,
4242 return out;
4343}
4444
45+ std::ostream &operator <<(std::ostream &out, const ur_device_handle_t &device) {
46+ size_t size;
47+ urDeviceGetInfo (device, UR_DEVICE_INFO_NAME, 0 , nullptr , &size);
48+ std::vector<char > name (size);
49+ urDeviceGetInfo (device, UR_DEVICE_INFO_NAME, size, name.data (), nullptr );
50+ out << name.data ();
51+ return out;
52+ }
53+
54+ std::ostream &operator <<(std::ostream &out,
55+ const std::vector<ur_device_handle_t > &devices) {
56+ for (auto device : devices) {
57+ out << " \n * \" " << device << " \" " ;
58+ }
59+ return out;
60+ }
61+
4562uur::PlatformEnvironment::PlatformEnvironment (int argc, char **argv)
4663 : platform_options{parsePlatformOptions (argc, argv)} {
4764 instance = this ;
@@ -101,14 +118,16 @@ uur::PlatformEnvironment::PlatformEnvironment(int argc, char **argv)
101118 }
102119
103120 if (platform_options.platform_name .empty ()) {
104- if (platforms.size () == 1 ) {
121+
122+ if (platforms.size () == 1 || platform_options.platforms_count == 1 ) {
105123 platform = platforms[0 ];
106124 } else {
107125 std::stringstream ss_error;
108126 ss_error << " Select a single platform from below using the "
109127 " --platform=NAME "
110128 " command-line option:"
111- << platforms;
129+ << platforms << std::endl
130+ << " or set --platforms_count=1." ;
112131 error = ss_error.str ();
113132 return ;
114133 }
@@ -137,7 +156,8 @@ uur::PlatformEnvironment::PlatformEnvironment(int argc, char **argv)
137156 << " \" not found. Select a single platform from below "
138157 " using the "
139158 " --platform=NAME command-line options:"
140- << platforms;
159+ << platforms << std::endl
160+ << " or set --platforms_count=1." ;
141161 error = ss_error.str ();
142162 return ;
143163 }
@@ -178,6 +198,10 @@ PlatformEnvironment::parsePlatformOptions(int argc, char **argv) {
178198 arg, " --platform=" , sizeof (" --platform=" ) - 1 ) == 0 ) {
179199 options.platform_name =
180200 std::string (&arg[std::strlen (" --platform=" )]);
201+ } else if (std::strncmp (arg, " --platforms_count=" ,
202+ sizeof (" --platforms_count=" ) - 1 ) == 0 ) {
203+ options.platforms_count = std::strtoul (
204+ &arg[std::strlen (" --platforms_count=" )], nullptr , 10 );
181205 }
182206 }
183207
@@ -193,10 +217,31 @@ PlatformEnvironment::parsePlatformOptions(int argc, char **argv) {
193217 return options;
194218}
195219
220+ DevicesEnvironment::DeviceOptions
221+ DevicesEnvironment::parseDeviceOptions (int argc, char **argv) {
222+ DeviceOptions options;
223+ for (int argi = 1 ; argi < argc; ++argi) {
224+ const char *arg = argv[argi];
225+ if (!(std::strcmp (arg, " -h" ) && std::strcmp (arg, " --help" ))) {
226+ // TODO - print help
227+ break ;
228+ } else if (std::strncmp (arg, " --device=" , sizeof (" --device=" ) - 1 ) ==
229+ 0 ) {
230+ options.device_name = std::string (&arg[std::strlen (" --device=" )]);
231+ } else if (std::strncmp (arg, " --devices_count=" ,
232+ sizeof (" --devices_count=" ) - 1 ) == 0 ) {
233+ options.devices_count = std::strtoul (
234+ &arg[std::strlen (" --devices_count=" )], nullptr , 10 );
235+ }
236+ }
237+ return options;
238+ }
239+
196240DevicesEnvironment *DevicesEnvironment::instance = nullptr ;
197241
198242DevicesEnvironment::DevicesEnvironment (int argc, char **argv)
199- : PlatformEnvironment(argc, argv) {
243+ : PlatformEnvironment(argc, argv),
244+ device_options (parseDeviceOptions(argc, argv)) {
200245 instance = this ;
201246 if (!error.empty ()) {
202247 return ;
@@ -210,27 +255,64 @@ DevicesEnvironment::DevicesEnvironment(int argc, char **argv)
210255 error = " Could not find any devices associated with the platform" ;
211256 return ;
212257 }
213- // Get the argument (test_devices_count) to limit test devices count.
214- u_long count_set = 0 ;
215- for (int i = 1 ; i < argc; ++i) {
216- if (std::strcmp (argv[i], " --test_devices_count" ) == 0 && i + 1 < argc) {
217- count_set = std::strtoul (argv[i + 1 ], nullptr , 10 );
218- break ;
219- }
220- }
221- // In case, the count_set is "0", the variable count will not be changed.
258+
259+ // Get the argument (devices_count) to limit test devices count.
260+ // In case, the devices_count is "0", the variable count will not be changed.
222261 // The CTS will run on all devices.
223- if (count_set > (std::numeric_limits<uint32_t >::max)()) {
224- error = " Invalid test_devices_count argument" ;
225- return ;
226- } else if (count_set > 0 ) {
227- count = (std::min)(count, static_cast <uint32_t >(count_set));
228- }
229- devices.resize (count);
230- if (urDeviceGet (platform, UR_DEVICE_TYPE_ALL, count, devices.data (),
231- nullptr )) {
232- error = " urDeviceGet() failed to get devices." ;
233- return ;
262+ if (device_options.device_name .empty ()) {
263+ if (device_options.devices_count >
264+ (std::numeric_limits<uint32_t >::max)()) {
265+ error = " Invalid devices_count argument" ;
266+ return ;
267+ } else if (device_options.devices_count > 0 ) {
268+ count = (std::min)(
269+ count, static_cast <uint32_t >(device_options.devices_count ));
270+ }
271+ devices.resize (count);
272+ if (urDeviceGet (platform, UR_DEVICE_TYPE_ALL, count, devices.data (),
273+ nullptr )) {
274+ error = " urDeviceGet() failed to get devices." ;
275+ return ;
276+ }
277+ } else {
278+ devices.resize (count);
279+ if (urDeviceGet (platform, UR_DEVICE_TYPE_ALL, count, devices.data (),
280+ nullptr )) {
281+ error = " urDeviceGet() failed to get devices." ;
282+ return ;
283+ }
284+ for (u_long i = 0 ; i < count; i++) {
285+ size_t size;
286+ if (urDeviceGetInfo (devices[i], UR_DEVICE_INFO_NAME, 0 , nullptr ,
287+ &size)) {
288+ error = " urDeviceGetInfo() failed" ;
289+ return ;
290+ }
291+ std::vector<char > device_name (size);
292+ if (urDeviceGetInfo (devices[i], UR_DEVICE_INFO_NAME, size,
293+ device_name.data (), nullptr )) {
294+ error = " urDeviceGetInfo() failed" ;
295+ return ;
296+ }
297+ if (device_options.device_name == device_name.data ()) {
298+ device = devices[i];
299+ devices.clear ();
300+ devices.resize (1 );
301+ devices[0 ] = device;
302+ break ;
303+ }
304+ }
305+ if (!device) {
306+ std::stringstream ss_error;
307+ ss_error << " Device \" " << device_options.device_name
308+ << " \" not found. Select a single device from below "
309+ " using the "
310+ " --device=NAME command-line options:"
311+ << devices << std::endl
312+ << " or set --devices_count=COUNT." ;
313+ error = ss_error.str ();
314+ return ;
315+ }
234316 }
235317}
236318
0 commit comments