@@ -391,3 +391,68 @@ def compilation_info():
391391 assert compilation_info is not None
392392 assert compilation_info .lowering_config == lowering_config
393393 assert compilation_info .translation_info == translation_info
394+
395+
396+ @run
397+ def gpu_target_info_attribute_parsing ():
398+ mlir_string = """
399+ hal.executable private @main_dispatch_0 {
400+ hal.executable.variant public @rocm_hsaco_fb
401+ target(<"rocm", "rocm-hsaco-fb",
402+ {
403+ abi = "hip",
404+ iree_codegen.target_info = #iree_gpu.target<
405+ arch = "gfx942",
406+ features = "",
407+ wgp = <
408+ compute = fp64,
409+ storage = b64,
410+ subgroup = none,
411+ dot = none,
412+ mma = [<MFMA_F32_16x16x4_F32>],
413+ subgroup_size_choices = [32, 64],
414+ max_workgroup_sizes = [256, 512, 1024],
415+ max_thread_count_per_workgroup = 1024,
416+ max_workgroup_memory_bytes = 65536,
417+ max_workgroup_counts = [256, 512, 1024]
418+ >
419+ >
420+ }>
421+ ) {
422+ }
423+ }
424+ """
425+
426+ module = ir .Module .parse (mlir_string )
427+ variant_op_list = iree_codegen .get_executable_variant_ops (module )
428+ assert len (variant_op_list ) == 1 , "Expect one executable variant op"
429+ variant_op = variant_op_list [0 ]
430+ executable_variant_op = variant_op .opview
431+ target = executable_variant_op .target
432+ gpu_target_info = iree_gpu .get_gpu_target_info (target )
433+
434+ arch = gpu_target_info .arch
435+ assert arch == "gfx942" , f"Expected arch 'gfx942', got '{ arch } '"
436+
437+ subgroup_size_choices = gpu_target_info .subgroup_size_choices
438+ assert subgroup_size_choices == [
439+ 32 ,
440+ 64 ,
441+ ], f"Expected subgroup_size_choice [32, 64], got { subgroup_size_choices } "
442+
443+ max_thread_count = gpu_target_info .max_thread_count_per_workgroup
444+ assert (
445+ max_thread_count == 1024
446+ ), f"Expected max_thread_count_per_workgroup 1024, got { max_thread_count } "
447+
448+ max_memory_bytes = gpu_target_info .max_workgroup_memory_bytes
449+ assert (
450+ max_memory_bytes == 65536
451+ ), f"Expected max_workgroup_memory_bytes 65536, got { max_memory_bytes } "
452+
453+ max_workgroup_sizes = gpu_target_info .max_workgroup_sizes
454+ assert max_workgroup_sizes == [
455+ 256 ,
456+ 512 ,
457+ 1024 ,
458+ ], f"Expected max_workgroup_sizes [256, 512, 1024], got { max_workgroup_sizes } "
0 commit comments