diff --git a/cmake/anydsl_runtime-config.cmake.in b/cmake/anydsl_runtime-config.cmake.in
index 97998ea0..6033e9a6 100644
--- a/cmake/anydsl_runtime-config.cmake.in
+++ b/cmake/anydsl_runtime-config.cmake.in
@@ -277,6 +277,7 @@ function(anydsl_runtime_wrap outfiles)
             ${AnyDSL_runtime_ROOT_DIR}/platforms/${_frontend}/intrinsics_amdgpu.impala
             ${AnyDSL_runtime_ROOT_DIR}/platforms/${_frontend}/intrinsics_opencl.impala
             ${AnyDSL_runtime_ROOT_DIR}/platforms/${_frontend}/intrinsics_thorin.impala
+            ${AnyDSL_runtime_ROOT_DIR}/platforms/${_frontend}/intrinsics_vulkan.impala
             ${AnyDSL_runtime_ROOT_DIR}/platforms/${_frontend}/runtime.impala
             ${_additional_platform_files})
diff --git a/platforms/artic/intrinsics_spirv.impala b/platforms/artic/intrinsics_spirv.impala
index 570ff470..793f2da2 100644
--- a/platforms/artic/intrinsics_spirv.impala
+++ b/platforms/artic/intrinsics_spirv.impala
@@ -1 +1,4 @@
-#[import(cc = "device", name = "spirv.builtin")] fn spirv_get_builtin[T](i32) -> T;
\ No newline at end of file
+#[import(cc = "device", name = "spirv.builtin")] fn spirv_get_builtin[T](i32) -> T;
+
+#[import(cc = "device", name = "spirv.global")] fn spirv_make_global_variable[T]() -> T;
+#[import(cc = "device", name = "spirv.decorate")] fn spirv_decorate_literal[T](T, u32, u32) -> ();
diff --git a/platforms/artic/intrinsics_thorin.impala b/platforms/artic/intrinsics_thorin.impala
index c91210a1..276d8eaa 100644
--- a/platforms/artic/intrinsics_thorin.impala
+++ b/platforms/artic/intrinsics_thorin.impala
@@ -19,6 +19,17 @@
 #[import(cc = "thorin")] fn amdgpu_hsa(_dev: i32, _grid: (i32, i32, i32), _block: (i32, i32, i32), _body: fn() -> ()) -> ();
 #[import(cc = "thorin")] fn amdgpu_pal(_dev: i32, _grid: (i32, i32, i32), _block: (i32, i32, i32), _body: fn() -> ()) -> ();
 #[import(cc = "thorin")] fn levelzero(_dev: i32, _grid: (i32, i32, i32), _block: (i32, i32, i32), _body: fn() -> ()) -> ();
+#[import(cc = "thorin")] fn vulkan_cs(_dev: i32, _grid: (i32, i32, i32), _block: (i32, i32, i32), _body: fn() -> ()) -> ();
+
+struct VulkanOffloadInfo {
+    // talks to the runtime to set up this pipeline for execution
+    setup_offloaded_args: fn() -> (),
+    filename: &[u8],
+    num_stages: u32,
+    stages: &[(u32, &[u8])]
+}
+
+//#[import(cc = "thorin")] fn vulkan_offload(_num_stages: u32, stages: &[(u32, fn() -> ())]) -> VulkanOffloadInfo;
 #[import(cc = "thorin")] fn reserve_shared[T](_size: i32) -> &mut addrspace(3)[T];
 #[import(cc = "thorin")] fn hls(_dev: i32, _body: fn() -> ()) -> ();
 #[import(cc = "thorin", name = "pipeline")] fn thorin_pipeline(_initiation_interval: i32, _lower: i32, _upper: i32, _body: fn(i32) -> ()) -> (); // only for HLS/OpenCL backend
diff --git a/platforms/artic/intrinsics_vulkan.impala b/platforms/artic/intrinsics_vulkan.impala
new file mode 100644
index 00000000..6f742cde
--- /dev/null
+++ b/platforms/artic/intrinsics_vulkan.impala
@@ -0,0 +1,153 @@
+// no declarations are emitted for "device" functions
+#[import(cc = "device", name = "barrier")] fn vulkan_barrier(u32) -> ();
+#[import(cc = "device", name = "exp")] fn vulkan_expf(f32) -> f32;
+#[import(cc = "device", name = "exp2")] fn vulkan_exp2f(f32) -> f32;
+#[import(cc = "device", name = "log")] fn vulkan_logf(f32) -> f32;
+#[import(cc = "device", name = "log2")] fn vulkan_log2f(f32) -> f32;
+#[import(cc = "device", name = "pow")] fn vulkan_powf(f32, f32) -> f32;
+#[import(cc = "device", name = "rsqrt")] fn vulkan_rsqrtf(f32) -> f32;
+#[import(cc = "device", name = "sqrt")] fn vulkan_sqrtf(f32) -> f32;
+#[import(cc = "device", name = "fabs")] fn vulkan_fabsf(f32) -> f32;
+#[import(cc = "device", name = "sin")] fn vulkan_sinf(f32) -> f32;
+#[import(cc = "device", name = "cos")] fn vulkan_cosf(f32) -> f32;
+#[import(cc = "device", name = "tan")] fn vulkan_tanf(f32) -> f32;
+#[import(cc = "device", name = "asin")] fn vulkan_asinf(f32) -> f32;
+#[import(cc = "device", name = "acos")] fn vulkan_acosf(f32) -> f32;
+#[import(cc = "device", name = "atan")] fn vulkan_atanf(f32) -> f32;
+#[import(cc = "device", name = "erf")] fn vulkan_erff(f32) -> f32;
+#[import(cc = "device", name = "atan2")] fn vulkan_atan2f(f32, f32) -> f32;
+#[import(cc = "device", name = "fmod")] fn vulkan_fmodf(f32, f32) -> f32;
+#[import(cc = "device", name = "floor")] fn vulkan_floorf(f32) -> f32;
+#[import(cc = "device", name = "isinf")] fn vulkan_isinff(f32) -> i32;
+#[import(cc = "device", name = "isnan")] fn vulkan_isnanf(f32) -> i32;
+#[import(cc = "device", name = "isfinite")] fn vulkan_isfinitef(f32) -> i32;
+#[import(cc = "device", name = "fma")] fn vulkan_fmaf(f32, f32, f32) -> f32;
+#[import(cc = "device", name = "mad")] fn vulkan_madf(f32, f32, f32) -> f32;
+#[import(cc = "device", name = "copysign")] fn vulkan_copysignf(f32, f32) -> f32;
+#[import(cc = "device", name = "exp")] fn vulkan_exp(f64) -> f64;
+#[import(cc = "device", name = "exp2")] fn vulkan_exp2(f64) -> f64;
+#[import(cc = "device", name = "log")] fn vulkan_log(f64) -> f64;
+#[import(cc = "device", name = "log2")] fn vulkan_log2(f64) -> f64;
+#[import(cc = "device", name = "pow")] fn vulkan_pow(f64, f64) -> f64;
+#[import(cc = "device", name = "rsqrt")] fn vulkan_rsqrt(f64) -> f64;
+#[import(cc = "device", name = "sqrt")] fn vulkan_sqrt(f64) -> f64;
+#[import(cc = "device", name = "fabs")] fn vulkan_fabs(f64) -> f64;
+#[import(cc = "device", name = "sin")] fn vulkan_sin(f64) -> f64;
+#[import(cc = "device", name = "cos")] fn vulkan_cos(f64) -> f64;
+#[import(cc = "device", name = "tan")] fn vulkan_tan(f64) -> f64;
+#[import(cc = "device", name = "asin")] fn vulkan_asin(f64) -> f64;
+#[import(cc = "device", name = "acos")] fn vulkan_acos(f64) -> f64;
+#[import(cc = "device", name = "atan")] fn vulkan_atan(f64) -> f64;
+#[import(cc = "device", name = "erf")] fn vulkan_erf(f64) -> f64;
+#[import(cc = "device", name = "atan2")] fn vulkan_atan2(f64, f64) -> f64;
+#[import(cc = "device", name = "fmod")] fn vulkan_fmod(f64, f64) -> f64;
+#[import(cc = "device", name = "floor")] fn vulkan_floor(f64) -> f64;
+#[import(cc = "device", name = "isinf")] fn vulkan_isinf(f64) -> i32;
+#[import(cc = "device", name = "isnan")] fn vulkan_isnan(f64) -> i32;
+#[import(cc = "device", name = "isfinite")] fn vulkan_isfinite(f64) -> i32;
+#[import(cc = "device", name = "fma")] fn vulkan_fma(f64, f64, f64) -> f64;
+#[import(cc = "device", name = "mad")] fn vulkan_mad(f64, f64, f64) -> f64;
+#[import(cc = "device", name = "copysign")] fn vulkan_copysign(f64, f64) -> f64;
+#[import(cc = "device", name = "fmin")] fn vulkan_fminf(f32, f32) -> f32;
+#[import(cc = "device", name = "fmax")] fn vulkan_fmaxf(f32, f32) -> f32;
+#[import(cc = "device", name = "fmin")] fn vulkan_fmin(f64, f64) -> f64;
+#[import(cc = "device", name = "fmax")] fn vulkan_fmax(f64, f64) -> f64;
+#[import(cc = "device", name = "min")] fn vulkan_min(i32, i32) -> i32;
+#[import(cc = "device", name = "max")] fn vulkan_max(i32, i32) -> i32;
+#[import(cc = "device", name = "atomic_add")] fn vulkan_atomic_add_global(&mut addrspace(1)i32, i32) -> i32;
+#[import(cc = "device", name = "atomic_add")] fn vulkan_atomic_add_shared(&mut addrspace(3)i32, i32) -> i32;
+#[import(cc = "device", name = "atomic_min")] fn vulkan_atomic_min_global(&mut addrspace(1)i32, i32) -> i32;
+#[import(cc = "device", name = "atomic_min")] fn vulkan_atomic_min_shared(&mut addrspace(3)i32, i32) -> i32;
+
+fn spv_vk_get_num_groups() = *spirv_get_builtin[&mut addrspace(8) simd[u32 * 3]](24 /* BuiltInNumWorkgroups */);
+fn spv_vk_get_local_size() = *spirv_get_builtin[&mut addrspace(8) simd[u32 * 3]](25 /* BuiltInWorkgroupSize */);
+fn spv_vk_get_group_id()   = *spirv_get_builtin[&mut addrspace(8) simd[u32 * 3]](26 /* BuiltInWorkgroupId */);
+fn spv_vk_get_local_id()   = *spirv_get_builtin[&mut addrspace(8) simd[u32 * 3]](27 /* BuiltInLocalInvocationId */);
+fn spv_vk_get_global_id()  = *spirv_get_builtin[&mut addrspace(8) simd[u32 * 3]](28 /* BuiltInGlobalInvocationId */);
+
+fn @vulkan_get_global_size(dim: u32) -> i32 = (spv_vk_get_local_size()(dim) * spv_vk_get_num_groups()(dim)) as i32;
+
+fn @vulkan_accelerator(dev: i32) = Accelerator {
+    exec = @|body| |grid, block| {
+        let work_item = WorkItem {
+            tidx  = @|| spv_vk_get_local_id()(0) as i32,
+            tidy  = @|| spv_vk_get_local_id()(1) as i32,
+            tidz  = @|| spv_vk_get_local_id()(2) as i32,
+            bidx  = @|| spv_vk_get_group_id()(0) as i32,
+            bidy  = @|| spv_vk_get_group_id()(1) as i32,
+            bidz  = @|| spv_vk_get_group_id()(2) as i32,
+            gidx  = @|| spv_vk_get_global_id()(0) as i32,
+            gidy  = @|| spv_vk_get_global_id()(1) as i32,
+            gidz  = @|| spv_vk_get_global_id()(2) as i32,
+            bdimx = @|| spv_vk_get_local_size()(0) as i32,
+            bdimy = @|| spv_vk_get_local_size()(1) as i32,
+            bdimz = @|| spv_vk_get_local_size()(2) as i32,
+            gdimx = @|| vulkan_get_global_size(0),
+            gdimy = @|| vulkan_get_global_size(1),
+            gdimz = @|| vulkan_get_global_size(2),
+            nblkx = @|| spv_vk_get_num_groups()(0) as i32,
+            nblky = @|| spv_vk_get_num_groups()(1) as i32,
+            nblkz = @|| spv_vk_get_num_groups()(2) as i32
+        };
+        vulkan_cs(dev, grid, block, || @body(work_item))
+    },
+    sync  = @|| synchronize_vulkan(dev),
+    alloc = @|size| alloc_vulkan(dev, size),
+    alloc_unified = @|size| alloc_vulkan_unified(dev, size),
+    // CLK_LOCAL_MEM_FENCE comes from intrinsics_opencl.impala, which anydsl_runtime_wrap links in as well
+    barrier = @|| vulkan_barrier(CLK_LOCAL_MEM_FENCE),
+};
+
+static vulkan_intrinsics = Intrinsics {
+    expf      = vulkan_expf,
+    exp2f     = vulkan_exp2f,
+    logf      = vulkan_logf,
+    log2f     = vulkan_log2f,
+    powf      = vulkan_powf,
+    rsqrtf    = vulkan_rsqrtf,
+    sqrtf     = vulkan_sqrtf,
+    fabsf     = vulkan_fabsf,
+    sinf      = vulkan_sinf,
+    cosf      = vulkan_cosf,
+    tanf      = vulkan_tanf,
+    asinf     = vulkan_asinf,
+    acosf     = vulkan_acosf,
+    atanf     = vulkan_atanf,
+    erff      = vulkan_erff,
+    atan2f    = vulkan_atan2f,
+    copysignf = vulkan_copysignf,
+    fmaf      = vulkan_fmaf,
+    fmaxf     = vulkan_fmaxf,
+    fminf     = vulkan_fminf,
+    fmodf     = vulkan_fmodf,
+    floorf    = vulkan_floorf,
+    isinff    = vulkan_isinff,
+    isnanf    = vulkan_isnanf,
+    isfinitef = vulkan_isfinitef,
+    exp       = vulkan_exp,
+    exp2      = vulkan_exp2,
+    log       = vulkan_log,
+    log2      = vulkan_log2,
+    pow       = vulkan_pow,
+    rsqrt     = vulkan_rsqrt,
+    sqrt      = vulkan_sqrt,
+    fabs      = vulkan_fabs,
+    sin       = vulkan_sin,
+    cos       = vulkan_cos,
+    tan       = vulkan_tan,
+    asin      = vulkan_asin,
+    acos      = vulkan_acos,
+    atan      = vulkan_atan,
+    erf       = vulkan_erf,
+    atan2     = vulkan_atan2,
+    copysign  = vulkan_copysign,
+    fma       = vulkan_fma,
+    fmax      = vulkan_fmax,
+    fmin      = vulkan_fmin,
+    fmod      = vulkan_fmod,
+    floor     = vulkan_floor,
+    isinf     = vulkan_isinf,
+    isnan     = vulkan_isnan,
+    isfinite  = vulkan_isfinite,
+    min       = vulkan_min,
+    max       = vulkan_max,
+};
diff --git a/platforms/artic/runtime.impala b/platforms/artic/runtime.impala
index 456a9412..013125c5 100644
--- a/platforms/artic/runtime.impala
+++ b/platforms/artic/runtime.impala
@@ -31,6 +31,8 @@
 #[import(cc = "C", name = "anydsl_print_string")] fn print_string(_: &[u8]) -> ();
 #[import(cc = "C", name = "anydsl_print_flush")] fn print_flush() -> ();
 
+#[import(cc = "C", name = "anydsl_load_offloaded")] fn runtime_load_offloaded(_device: i32, _filename: &[u8], _name: &[u8], _size: &mut u64) -> &[u8];
+
 // TODO
 //struct Buffer[T] {
 //    data : &mut [T],
@@ -123,6 +125,9 @@
 fn @alloc_levelzero(dev: i32, size: i64) = alloc(runtime_device(5, dev), size);
 fn @alloc_levelzero_host(dev: i32, size: i64) = alloc_host(runtime_device(5, dev), size);
 fn @alloc_levelzero_unified(dev: i32, size: i64) = alloc_unified(runtime_device(5, dev), size);
 fn @synchronize_levelzero(dev: i32) = runtime_synchronize(runtime_device(5, dev));
+fn @synchronize_vulkan(dev: i32) = runtime_synchronize(runtime_device(6, dev));
+fn @alloc_vulkan(dev: i32, size: i64) = alloc(runtime_device(6, dev), size);
+fn @alloc_vulkan_unified(dev: i32, size: i64) = alloc_unified(runtime_device(6, dev), size);
 
 fn @copy(src: Buffer, dst: Buffer) = runtime_copy(src.device, src.data, 0, dst.device, dst.data, 0, src.size);
 fn @copy_offset(src: Buffer, off_src: i64, dst: Buffer, off_dst: i64, size: i64) = runtime_copy(src.device, src.data, off_src, dst.device, dst.data, off_dst, size);
diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt
index 25e0f18c..3c9fb157 100644
--- a/src/CMakeLists.txt
+++ b/src/CMakeLists.txt
@@ -139,6 +139,20 @@ if(pal_FOUND)
 endif()
 set(AnyDSL_runtime_HAS_PAL_SUPPORT ${pal_FOUND} CACHE INTERNAL "enables PAL support")
 
+find_package(shady)
+find_package(Vulkan)
+if(Vulkan_FOUND AND shady_FOUND)
+    add_library(runtime_vulkan STATIC vulkan_platform.cpp vulkan_platform.h)
+    target_include_directories(runtime_vulkan PRIVATE ${Vulkan_INCLUDE_DIRS})
+    target_link_libraries(runtime_vulkan PRIVATE runtime_base ${Vulkan_LIBRARIES} shady::runtime shady::api shady::driver)
+    list(APPEND RUNTIME_PLATFORMS runtime_vulkan)
+    set(AnyDSL_runtime_HAS_VULKAN_SUPPORT TRUE CACHE INTERNAL "enables Vulkan support")
+else()
+    # only advertise Vulkan support when the platform library is actually built
+    set(AnyDSL_runtime_HAS_VULKAN_SUPPORT FALSE CACHE INTERNAL "enables Vulkan support")
+endif()
+set(AnyDSL_runtime_HAS_SHADY_SUPPORT ${shady_FOUND} CACHE INTERNAL "enables shady support")
+
 # look for LLVM for nvptx and gcn
 find_package(LLVM CONFIG)
 if(LLVM_FOUND)
diff --git a/src/anydsl_runtime.cpp b/src/anydsl_runtime.cpp
index d7582496..ea2b5a80 100644
--- a/src/anydsl_runtime.cpp
+++ b/src/anydsl_runtime.cpp
@@ -37,6 +37,7 @@ struct RuntimeSingleton {
         register_hsa_platform(&runtime);
         register_pal_platform(&runtime);
         register_levelzero_platform(&runtime);
+        register_vulkan_platform(&runtime);
     }
 
     static std::pair<ProfileLevel, ProfileLevel> detect_profile_level() {
diff --git a/src/anydsl_runtime.h b/src/anydsl_runtime.h
index 80bc6bed..25f70691 100644
--- a/src/anydsl_runtime.h
+++ b/src/anydsl_runtime.h
@@ -18,7 +18,8 @@ enum {
     ANYDSL_OPENCL = 2,
     ANYDSL_HSA = 3,
     ANYDSL_PAL = 4,
-    ANYDSL_LEVELZERO = 5
+    ANYDSL_LEVELZERO = 5,
+    ANYDSL_VULKAN = 6
 };
 
 AnyDSL_runtime_API void anydsl_info(void);
diff --git a/src/anydsl_runtime.hpp b/src/anydsl_runtime.hpp
index 11789ac6..27f33d76 100644
--- a/src/anydsl_runtime.hpp
+++ b/src/anydsl_runtime.hpp
@@ -13,7 +13,8 @@ enum class Platform : int32_t {
     OpenCL = ANYDSL_OPENCL,
     HSA = ANYDSL_HSA,
     PAL = ANYDSL_PAL,
-    LevelZero = ANYDSL_LEVELZERO
+    LevelZero = ANYDSL_LEVELZERO,
+    Vulkan = ANYDSL_VULKAN
 };
 
 struct Device {
diff --git a/src/anydsl_runtime_config.h.in b/src/anydsl_runtime_config.h.in
index 999776ba..1f65151c 100644
--- a/src/anydsl_runtime_config.h.in
+++ b/src/anydsl_runtime_config.h.in
@@ -11,6 +11,8 @@
 #cmakedefine AnyDSL_runtime_HAS_LEVELZERO_SUPPORT
 #cmakedefine AnyDSL_runtime_HAS_HSA_SUPPORT
 #cmakedefine AnyDSL_runtime_HAS_PAL_SUPPORT
+#cmakedefine AnyDSL_runtime_HAS_SHADY_SUPPORT
+#cmakedefine AnyDSL_runtime_HAS_VULKAN_SUPPORT
 
 #cmakedefine AnyDSL_runtime_HAS_TBB_SUPPORT
diff --git a/src/platform.h b/src/platform.h
index b8d719a4..34f7be47 100644
--- a/src/platform.h
+++ b/src/platform.h
@@ -14,6 +14,7 @@ void register_opencl_platform(Runtime*);
 void register_hsa_platform(Runtime*);
 void register_pal_platform(Runtime*);
 void register_levelzero_platform(Runtime*);
+void register_vulkan_platform(Runtime*);
 
 /// A runtime platform. Exposes a set of devices, a copy function,
 /// and functions to allocate and release memory.
diff --git a/src/runtime.cpp b/src/runtime.cpp
index a698fa84..bdb4ce22 100644
--- a/src/runtime.cpp
+++ b/src/runtime.cpp
@@ -23,6 +23,9 @@ void register_levelzero_platform(Runtime* runtime) { runtime->register_platform<DummyPlatform>("Level Zero"); }
 #endif
 
+#ifndef AnyDSL_runtime_HAS_VULKAN_SUPPORT
+void register_vulkan_platform(Runtime* runtime) { runtime->register_platform<DummyPlatform>("Vulkan"); }
+#endif
+
 Runtime::Runtime(std::pair<ProfileLevel, ProfileLevel> profile)
     : profile_(profile)
diff --git a/src/vulkan_platform.cpp b/src/vulkan_platform.cpp
new file mode 100644
index 00000000..4f97ee10
--- /dev/null
+++ b/src/vulkan_platform.cpp
@@ -0,0 +1,673 @@
+#include "vulkan_platform.h"
+
+namespace shady {
+extern "C" {
+#include "shady/jit/vulkan.h"
+#include "shady/be/spirv.h"
+}
+}
+
+const auto khr_validation = "VK_LAYER_KHRONOS_validation";
+
+#define CHECK(stuff) { \
+    auto rslt = stuff; \
+    if (rslt != VK_SUCCESS) \
+        error("Vulkan call failed: %", #stuff); \
+}
+
+template<typename T, typename U>
+void insert_pnext(T& base, U& append) {
+    assert(base.pNext == nullptr);
+    append.pNext = base.pNext;
+    base.pNext = &append;
+}
+
+inline std::vector<VkLayerProperties> query_layers_available() {
+    uint32_t count;
+    vkEnumerateInstanceLayerProperties(&count, nullptr);
+    std::vector<VkLayerProperties> layers(count);
+    vkEnumerateInstanceLayerProperties(&count, layers.data());
+    return layers;
+}
+
+inline std::vector<VkExtensionProperties> query_extensions_available() {
+    uint32_t count;
+    vkEnumerateInstanceExtensionProperties(nullptr, &count, nullptr);
+    std::vector<VkExtensionProperties> exts(count);
+    vkEnumerateInstanceExtensionProperties(nullptr, &count, exts.data());
+    return exts;
+}
+
+inline bool is_ext_available(std::vector<VkExtensionProperties>& ext_props, std::string ext_name) {
+    for (auto& ext : ext_props) {
+        if (strcmp(ext.extensionName, ext_name.c_str()) == 0)
+            return true;
+    }
+    return false;
+}
+
+VulkanPlatform::VulkanPlatform(Runtime* runtime) : Platform(runtime) {
+    auto available_layers = query_layers_available();
+    auto available_instance_extensions = query_extensions_available();
+
+    std::vector<const char*> enabled_layers;
+    std::vector<const char*> enabled_instance_extensions {
+        "VK_KHR_external_memory_capabilities"
+    };
+
+    bool should_enable_validation = true;
+#ifdef NDEBUG
+    should_enable_validation = false;
+#endif
+    if (should_enable_validation) {
+        for (auto& layer : available_layers) {
+            if (strcmp(khr_validation, layer.layerName) == 0) {
+                enabled_layers.push_back(khr_validation);
+                goto validation_done;
+            }
+        }
+        info("Warning: validation was requested, but the Khronos validation layer is not present");
+    }
+    validation_done:
+
+    auto app_info = VkApplicationInfo {
+        .sType = VK_STRUCTURE_TYPE_APPLICATION_INFO,
+        .pApplicationName = "AnyDSL Runtime",
+        .apiVersion = VK_API_VERSION_1_2,
+    };
+    auto create_info = VkInstanceCreateInfo {
+        .sType = VK_STRUCTURE_TYPE_INSTANCE_CREATE_INFO,
+        .pNext = nullptr,
+        .pApplicationInfo = &app_info,
+        .enabledLayerCount = (uint32_t) enabled_layers.size(),
+        .ppEnabledLayerNames = enabled_layers.data(),
+        .enabledExtensionCount = (uint32_t) enabled_instance_extensions.size(),
+        .ppEnabledExtensionNames = enabled_instance_extensions.data(),
+    };
+    CHECK(vkCreateInstance(&create_info, nullptr, &instance));
+
+    uint32_t physical_devices_count;
+    vkEnumeratePhysicalDevices(instance, &physical_devices_count, nullptr);
+    physical_devices.resize(physical_devices_count);
+    vkEnumeratePhysicalDevices(instance, &physical_devices_count, physical_devices.data());
+
+    debug("Available Vulkan physical devices: ");
+    size_t i = 0;
+    for (auto& dev : physical_devices) {
+        usable_devices.emplace_back(std::make_unique<Device>(*this, dev, i));
+        i++;
+    }
+    debug("Vulkan platform successfully initialized");
+}
+
+VulkanPlatform::~VulkanPlatform() {
+    usable_devices.clear();
+    vkDestroyInstance(instance, nullptr);
+}
+
+VulkanPlatform::Device::Device(VulkanPlatform& platform, VkPhysicalDevice physical_device, size_t device_id)
+: platform_(platform), physical_device(physical_device), device_id(device_id) {
+    uint32_t exts_count;
+    vkEnumerateDeviceExtensionProperties(physical_device, nullptr, &exts_count, nullptr);
+    std::vector<VkExtensionProperties> available_device_extensions(exts_count);
+    vkEnumerateDeviceExtensionProperties(physical_device, nullptr, &exts_count, available_device_extensions.data());
+
+    std::vector<const char*> enabled_device_extensions {
+        "VK_KHR_buffer_device_address",
+        "VK_KHR_shader_non_semantic_info"
+    };
+
+    // If available, use this extension to import host memory as GPU-visible memory;
+    // otherwise a fallback path copies through a staging buffer when uploading/downloading
+    if (is_ext_available(available_device_extensions, "VK_EXT_external_memory_host")) {
+        enabled_device_extensions.push_back("VK_EXT_external_memory_host");
+        insert_pnext(properties, external_memory_host_properties);
+        can_import_host_memory = true;
+    }
+
+    vkGetPhysicalDeviceProperties2(physical_device, &properties);
+    auto& device_properties = properties.properties;
+
+    debug("  GPU%:", device_id);
+    debug("  Device name: %", device_properties.deviceName);
+    debug("  Vulkan version %.%.%", VK_VERSION_MAJOR(device_properties.apiVersion), VK_VERSION_MINOR(device_properties.apiVersion), VK_VERSION_PATCH(device_properties.apiVersion));
+
+    if (can_import_host_memory) {
+        debug("  Min imported host ptr alignment: %", external_memory_host_properties.minImportedHostPointerAlignment);
+        if (external_memory_host_properties.minImportedHostPointerAlignment == 0xFFFFFFFF)
+            error("Device does not report minimum host pointer alignment");
+    }
+
+    uint32_t queue_families_count;
+    vkGetPhysicalDeviceQueueFamilyProperties(physical_device, &queue_families_count, nullptr);
+    std::vector<VkQueueFamilyProperties> queue_families(queue_families_count);
+    vkGetPhysicalDeviceQueueFamilyProperties(physical_device, &queue_families_count, queue_families.data());
+    int compute_queue_family = -1;
+    int q = 0;
+    for (auto& queue_f : queue_families) {
+        bool has_gfx       = (queue_f.queueFlags & VK_QUEUE_GRAPHICS_BIT) != 0;
+        bool has_compute   = (queue_f.queueFlags & VK_QUEUE_COMPUTE_BIT) != 0;
+        bool has_xfer      = (queue_f.queueFlags & VK_QUEUE_TRANSFER_BIT) != 0;
+        bool has_sparse    = (queue_f.queueFlags & VK_QUEUE_SPARSE_BINDING_BIT) != 0;
+        bool has_protected = (queue_f.queueFlags & VK_QUEUE_PROTECTED_BIT) != 0;
+
+        // TODO perform this selection intelligently
+        if (compute_queue_family == -1 && has_compute)
+            compute_queue_family = q;
+        q++;
+    }
+    std::vector<VkDeviceQueueCreateInfo> queue_create_infos;
+    float one = 1.0f;
+    if (compute_queue_family != -1) {
+        queue_create_infos.push_back(VkDeviceQueueCreateInfo {
+            .sType = VK_STRUCTURE_TYPE_DEVICE_QUEUE_CREATE_INFO,
+            .pNext = nullptr,
+            .flags = 0,
+            .queueFamilyIndex = (uint32_t) compute_queue_family,
+            .queueCount = 1,
+            .pQueuePriorities = &one
+        });
+    } else {
+        assert(false && "unsuitable device");
+    }
+
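+    // Kernel arguments are handed over as raw device addresses in push constants
+    // (see Kernel::setup), so the buffer-device-address and variable-pointer
+    // features below are hard requirements for this platform.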
+    auto bda_features = VkPhysicalDeviceBufferDeviceAddressFeaturesKHR {
+        .sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_BUFFER_DEVICE_ADDRESS_FEATURES_KHR,
+        .pNext = nullptr,
+        .bufferDeviceAddress = true,
+    };
+    auto vk11_features = VkPhysicalDeviceVulkan11Features {
+        .sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_VULKAN_1_1_FEATURES,
+        .pNext = &bda_features,
+        .variablePointersStorageBuffer = true,
+        .variablePointers = true,
+    };
+    auto enabled_features = VkPhysicalDeviceFeatures2 {
+        .sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FEATURES_2,
+        .pNext = &vk11_features,
+        .features = {
+            .vertexPipelineStoresAndAtomics = true,
+            .fragmentStoresAndAtomics = true,
+            .shaderInt64 = true,
+            // .shaderInt16 = true,
+        }
+    };
+
+    auto device_create_info = VkDeviceCreateInfo {
+        .sType = VK_STRUCTURE_TYPE_DEVICE_CREATE_INFO,
+        .pNext = &enabled_features,
+        .flags = 0,
+        .queueCreateInfoCount = (uint32_t) queue_create_infos.size(),
+        .pQueueCreateInfos = queue_create_infos.data(),
+        .enabledLayerCount = 0,
+        .ppEnabledLayerNames = nullptr,
+        .enabledExtensionCount = (uint32_t) enabled_device_extensions.size(),
+        .ppEnabledExtensionNames = enabled_device_extensions.data(),
+        .pEnabledFeatures = nullptr // controlled via VkPhysicalDeviceFeatures2
+    };
+    CHECK(vkCreateDevice(physical_device, &device_create_info, nullptr, &handle_));
+    vkGetDeviceQueue(handle_, compute_queue_family, 0, &queue);
+
+    auto cmd_pool_create_info = VkCommandPoolCreateInfo {
+        .sType = VK_STRUCTURE_TYPE_COMMAND_POOL_CREATE_INFO,
+        .pNext = nullptr,
+        .flags = VK_COMMAND_POOL_CREATE_TRANSIENT_BIT | VK_COMMAND_POOL_CREATE_RESET_COMMAND_BUFFER_BIT,
+        .queueFamilyIndex = (uint32_t) compute_queue_family,
+    };
+    CHECK(vkCreateCommandPool(handle_, &cmd_pool_create_info, nullptr, &cmd_pool));
+
+    // Load function pointers
+#define f(s) extension_fns.s = (PFN_##s) vkGetDeviceProcAddr(handle_, #s);
+    DevicesExtensionsFunctions(f)
+#undef f
+
+    bool device_ok = shady::shd_rt_vk_check_physical_device_suitability(physical_device, &shady_caps_);
+    assert(device_ok);
+    target_config_ = shady::shd_rt_vk_get_device_target_config(&platform_.compiler_config_, &shady_caps_);
+}
+
+VulkanPlatform::Device::~Device() {
+    vkDestroyCommandPool(handle_, cmd_pool, nullptr);
+    kernels.clear();
+    //if (!resources.empty()) {
+    //    info("Some vulkan resources were not released. Releasing those automatically...");
+    //    resources.clear();
+    //}
+    vkDestroyDevice(handle_, nullptr);
+}
+
+uint32_t VulkanPlatform::Device::find_suitable_memory_type(uint32_t memory_type_bits, VkMemoryPropertyFlags memory_flags, VkMemoryHeapFlags heap_flags) {
+    VkPhysicalDeviceMemoryProperties device_memory_properties;
+    vkGetPhysicalDeviceMemoryProperties(physical_device, &device_memory_properties);
+    for (size_t bit = 0; bit < 32; bit++) {
+        auto& memory_type = device_memory_properties.memoryTypes[bit];
+        auto& memory_heap = device_memory_properties.memoryHeaps[memory_type.heapIndex];
+
+        if ((memory_type_bits & (1 << bit)) != 0) {
+            if ((memory_type.propertyFlags & memory_flags) == memory_flags && (memory_heap.flags & heap_flags) == heap_flags)
+                return (uint32_t) bit;
+        }
+    }
+    assert(false && "Unable to find a suitable memory type");
+    abort(); // do not fall off the end of a non-void function in release builds
+}
+
+VkDeviceMemory VulkanPlatform::Device::allocate_memory(VkDeviceSize size, uint32_t memory_type_bits, VkMemoryPropertyFlags memory_flags, VkMemoryHeapFlags heap_flags) {
+    auto allocate_flags = VkMemoryAllocateFlagsInfo {
+        .sType = VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_FLAGS_INFO,
+        .pNext = nullptr,
+        .flags = VK_MEMORY_ALLOCATE_DEVICE_ADDRESS_BIT_KHR,
+        .deviceMask = 0
+    };
+
+    auto allocation_info = VkMemoryAllocateInfo {
+        .sType = VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_INFO,
+        .pNext = &allocate_flags,
+        .allocationSize = size, // the driver might want padding!
+        .memoryTypeIndex = find_suitable_memory_type(memory_type_bits, memory_flags, heap_flags),
+    };
+    VkDeviceMemory memory;
+    CHECK(vkAllocateMemory(handle_, &allocation_info, nullptr, &memory));
+
+    return memory;
+}
+
+std::pair<VkDeviceMemory, size_t> VulkanPlatform::Device::import_host_memory(void* ptr, size_t size) {
+    assert(can_import_host_memory && "This device does not support importing host memory");
+
+    size_t alignment = external_memory_host_properties.minImportedHostPointerAlignment;
+
+    // Round the range [ptr, ptr + size) outwards to alignment boundaries, e.g. for
+    // alignment 0x1000, ptr 0x12345 and size 0x100 we import [0x12000, 0x13000)
+    size_t mask = ~(alignment - 1);
+    size_t host_ptr = (size_t) ptr;
+    size_t aligned_host_ptr = host_ptr & mask;
+
+    size_t end = host_ptr + size;
+    size_t aligned_end = ((end + alignment - 1) / alignment) * alignment;
+    size_t aligned_size = aligned_end - aligned_host_ptr;
+
+    // where the memory we wanted to import will actually start
+    size_t offset = host_ptr - aligned_host_ptr;
+
+    // Find the corresponding device memory type index
+    VkMemoryHostPointerPropertiesEXT host_ptr_properties {
+        .sType = VK_STRUCTURE_TYPE_MEMORY_HOST_POINTER_PROPERTIES_EXT,
+    };
+    CHECK(extension_fns.vkGetMemoryHostPointerPropertiesEXT(handle_, VK_EXTERNAL_MEMORY_HANDLE_TYPE_HOST_ALLOCATION_BIT_EXT, (void*) aligned_host_ptr, &host_ptr_properties));
+    uint32_t memory_type = find_suitable_memory_type(host_ptr_properties.memoryTypeBits, VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT);
+
+    // Import the memory
+    auto import_ptr_info = VkImportMemoryHostPointerInfoEXT {
+        .sType = VK_STRUCTURE_TYPE_IMPORT_MEMORY_HOST_POINTER_INFO_EXT,
+        .pNext = nullptr,
+        .handleType = VK_EXTERNAL_MEMORY_HANDLE_TYPE_HOST_ALLOCATION_BIT_EXT,
+        .pHostPointer = (void*) aligned_host_ptr,
+    };
+    auto allocation_info = VkMemoryAllocateInfo {
+        .sType = VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_INFO,
+        .pNext = &import_ptr_info,
+        .allocationSize = (VkDeviceSize) aligned_size,
+        .memoryTypeIndex = memory_type
+    };
+    VkDeviceMemory imported_memory;
+    CHECK(vkAllocateMemory(handle_, &allocation_info, nullptr, &imported_memory));
+    return std::make_pair(imported_memory, offset);
+}
+
+VulkanPlatform::Buffer::Buffer(Device& device, size_t size, BackingStorage backing, VkBufferUsageFlags2 usage) : Resource(device) {
+    VkBufferCreateInfo buffer_create_info {
+        .sType = VK_STRUCTURE_TYPE_BUFFER_CREATE_INFO,
+        .pNext = nullptr,
+        .flags = 0,
+        .size = (VkDeviceSize) size,
+        .usage = static_cast<VkBufferUsageFlags>(usage),
+        .sharingMode = VK_SHARING_MODE_EXCLUSIVE,
+        .queueFamilyIndexCount = 0,
+        .pQueueFamilyIndices = nullptr,
+    };
+
+    auto create_buffer = [&]() { CHECK(vkCreateBuffer(device.handle_, &buffer_create_info, nullptr, &handle_)); };
+
+    if (const auto* import_host = std::get_if<ImportedHostMemory>(&backing)) {
+        VkExternalMemoryBufferCreateInfo external_mem_buffer_create_info {
+            .sType = VK_STRUCTURE_TYPE_EXTERNAL_MEMORY_BUFFER_CREATE_INFO,
+            .pNext = nullptr,
+            .handleTypes = VK_EXTERNAL_MEMORY_HANDLE_TYPE_HOST_ALLOCATION_BIT_EXT
+        };
+        insert_pnext(buffer_create_info, external_mem_buffer_create_info);
+        create_buffer();
+
+        size_t imported_offset;
+        std::tie(device_memory_, imported_offset) = device.import_host_memory(import_host->host_memory_, size);
+        vkBindBufferMemory(device.handle_, handle_, device_memory_, imported_offset);
+    } else if (std::get_if<DeviceMemory>(&backing)) {
+        create_buffer();
+        VkMemoryRequirements memory_requirements;
+        vkGetBufferMemoryRequirements(device.handle_, handle_, &memory_requirements);
+        device_memory_ = device.allocate_memory(memory_requirements.size, memory_requirements.memoryTypeBits, VK_MEMORY_PROPERTY_DEVICE_LOCAL_BIT);
+        vkBindBufferMemory(device.handle_, handle_, device_memory_, 0);
+    } else if (std::get_if<HostMemory>(&backing)) {
+        create_buffer();
+        VkMemoryRequirements memory_requirements;
+        vkGetBufferMemoryRequirements(device.handle_, handle_, &memory_requirements);
+        device_memory_ = device.allocate_memory(memory_requirements.size, memory_requirements.memoryTypeBits, VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT);
+        vkBindBufferMemory(device.handle_, handle_, device_memory_, 0);
+    } else {
+        abort();
+    }
+
+    if (usage & VK_BUFFER_USAGE_2_SHADER_DEVICE_ADDRESS_BIT) {
+        VkBufferDeviceAddressInfoKHR bda_info {
+            .sType = VK_STRUCTURE_TYPE_BUFFER_DEVICE_ADDRESS_INFO_KHR,
+            .pNext = nullptr,
+            .buffer = handle_
+        };
+        device_address_ = device.extension_fns.vkGetBufferDeviceAddressKHR(device.handle_, &bda_info);
+        assert(device_address_ != 0 && "vkGetBufferDeviceAddress failed");
+    }
+}
+
+uint64_t VulkanPlatform::Device::create_buffer_resource(size_t size, Buffer::BackingStorage backing, VkBufferUsageFlags usage) {
+    std::unique_ptr<Buffer> buffer = std::make_unique<Buffer>(*this, size, backing, usage);
+
+    assert(buffer->device_address_);
+    auto& b = *(buffers_[buffer->device_address_] = std::move(buffer));
+
+    return b.device_address_;
+}
+
+void* VulkanPlatform::alloc(DeviceId dev, int64_t size) {
+    auto& device = usable_devices[dev];
+    // The "pointer" handed back to the application is the buffer's VkDeviceAddress,
+    // not a host pointer; it is only dereferenceable from within kernels.
+    return reinterpret_cast<void*>(device->create_buffer_resource(size, Buffer::DeviceMemory(), Buffer::ALL_BUFFER_USAGE));
+}
+
+void* VulkanPlatform::alloc_host(DeviceId dev, int64_t size) {
+    auto& device = usable_devices[dev];
+    return reinterpret_cast<void*>(device->create_buffer_resource(size, Buffer::HostMemory(), Buffer::ALL_BUFFER_USAGE));
+}
+
+void* VulkanPlatform::get_device_ptr(DeviceId dev, void* ptr) {
+    command_unavailable("get_device_ptr");
+}
+
+void VulkanPlatform::release(DeviceId dev, void* ptr) {
+    if (ptr == nullptr)
+        return;
+
+    auto& device = usable_devices[dev];
+    auto found = device->buffers_.find(reinterpret_cast<VkDeviceAddress>(ptr));
+
+    if (found != device->buffers_.end()) {
+        device->buffers_.erase(found);
+        return;
+    }
+
+    assert(false && "Could not find such a buffer to release");
+}
+
+void VulkanPlatform::release_host(DeviceId dev, void* ptr) {
+    release(dev, ptr);
+}
+
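+// A Buffer owns both the VkBuffer handle and its backing VkDeviceMemory, whether
+// that memory was allocated on the device or imported from a host allocation.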
+VulkanPlatform::Buffer::~Buffer() {
+    if (device_memory_)
+        vkFreeMemory(device_.handle_, device_memory_, nullptr);
+    vkDestroyBuffer(device_.handle_, handle_, nullptr);
+}
+
+VulkanPlatform::Kernel::Kernel(Device& device, std::string file_name, std::string kernel_name) : device_(device) {
+    shady::TargetConfig specialized_target = device_.target_config_;
+    specialized_target.execution_model = shady::ShdExecutionModelCompute;
+    specialized_target.entry_point = kernel_name.c_str();
+
+    std::string program_src = device_.platform_.runtime_->load_file(file_name);
+    shd_driver_load_source_file(&device_.platform_.compiler_config_, &device_.target_config_, shady::SrcSPIRV, program_src.size(), program_src.c_str(), "test", &shady_module_);
+    // TODO: this will be removed in a future version of Shady
+    shady::CompilerConfig specialized_config = device_.platform_.compiler_config_;
+    shady::SPVBackendConfig backend_config;
+    shady::shd_jit_vk_get_compiler_config_for_device(&device_.shady_caps_, &device_.target_config_, &backend_config, &specialized_config);
+    shady::shd_jit_vk_compile_module(&shady_module_, &specialized_target, &backend_config, &specialized_config);
+    size_t spirv_size;
+    char* spirv_bytes;
+    shady::shd_emit_spirv(&specialized_config, &backend_config, shady_module_, &spirv_size, &spirv_bytes);
+
+    size_t interface_size;
+    shady::shd_rt_vk_get_module_interface(shady_module_, &interface_size, nullptr);
+    interface.resize(interface_size);
+    shady::shd_rt_vk_get_module_interface(shady_module_, &interface_size, interface.data());
+
+    for (auto& e : interface) {
+        if (e.dst_kind == shady::RuntimeInterfaceItem::SHD_RII_Dst_PushConstant)
+            push_constant_size = std::max(push_constant_size, e.dst_details.push_constant.offset + e.dst_details.push_constant.size);
+    }
+
+    auto shader_module_create_info = VkShaderModuleCreateInfo {
+        .sType = VK_STRUCTURE_TYPE_SHADER_MODULE_CREATE_INFO,
+        .pNext = nullptr,
+        .flags = 0,
+        .codeSize = spirv_size,
+        .pCode = reinterpret_cast<const uint32_t*>(spirv_bytes),
+    };
+    CHECK(vkCreateShaderModule(device.handle_, &shader_module_create_info, nullptr, &shader_module));
+
+    auto stage = VkPipelineShaderStageCreateInfo {
+        .sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO,
+        .pNext = nullptr,
+        .flags = 0,
+        .stage = VK_SHADER_STAGE_COMPUTE_BIT,
+        .module = shader_module,
+        .pName = kernel_name.c_str(),
+        .pSpecializationInfo = nullptr,
+    };
+
+    std::vector<VkPushConstantRange> push_constants {
+        VkPushConstantRange {
+            .stageFlags = VK_SHADER_STAGE_COMPUTE_BIT,
+            .offset = 0,
+            .size = static_cast<uint32_t>(push_constant_size)
+        }
+    };
+    auto layout_create_info = VkPipelineLayoutCreateInfo {
+        .sType = VK_STRUCTURE_TYPE_PIPELINE_LAYOUT_CREATE_INFO,
+        .pNext = nullptr,
+        .flags = 0,
+        .setLayoutCount = 0,
+        .pSetLayouts = nullptr,
+        .pushConstantRangeCount = (uint32_t) push_constants.size(),
+        .pPushConstantRanges = push_constants.data(),
+    };
+    CHECK(vkCreatePipelineLayout(device.handle_, &layout_create_info, nullptr, &layout));
+
+    auto compute_pipeline_create_info = VkComputePipelineCreateInfo {
+        .sType = VK_STRUCTURE_TYPE_COMPUTE_PIPELINE_CREATE_INFO,
+        .pNext = nullptr,
+        .flags = 0,
+        .stage = stage,
+        .layout = layout,
+        .basePipelineHandle = VK_NULL_HANDLE,
+        .basePipelineIndex = 0,
+    };
+    CHECK(vkCreateComputePipelines(device.handle_, VK_NULL_HANDLE, 1, &compute_pipeline_create_info, nullptr, &pipeline));
+}
+
+VulkanPlatform::Kernel* VulkanPlatform::Device::load_kernel(const std::string& filename, const std::string& kernel_name) {
+    auto key = filename + "::" + kernel_name;
+    auto ki = kernels.find(key);
+    if (ki == kernels.end()) {
+        auto [i, b] = kernels.emplace(key, std::make_unique<Kernel>(*this, filename, kernel_name));
+        return &*i->second;
+    }
+
+    return ki->second.get();
+}
+
+void VulkanPlatform::Kernel::setup(VkCommandBuffer cmdbuf, const LaunchParams& launch_params) {
+    vkCmdBindPipeline(cmdbuf, VK_PIPELINE_BIND_POINT_COMPUTE, pipeline);
+    std::vector<uint8_t> push_constants;
+    push_constants.resize(push_constant_size);
+
+    for (auto& e : interface) {
+        if (e.dst_kind == shady::RuntimeInterfaceItem::SHD_RII_Dst_PushConstant) {
+            switch (e.src_kind) {
+            case shady::RuntimeInterfaceItem::SHD_RII_Src_Param:
+                assert(e.dst_details.push_constant.size == launch_params.args.sizes[e.src_details.param.param_idx]);
+                memcpy(push_constants.data() + e.dst_details.push_constant.offset, launch_params.args.data[e.src_details.param.param_idx], e.dst_details.push_constant.size);
+                break;
+            default:
+                error("TODO: unhandled interface item source kind");
+            //case shady::RuntimeInterfaceItem::SHD_RII_Src_TmpAllocation:
+            //    break;
+            //case shady::RuntimeInterfaceItem::SHD_RII_Src_LiftedConstant:
+            //    break;
+            //case shady::RuntimeInterfaceItem::SHD_RII_Src_ScratchBuffer:
+            //    break;
+            }
+        } else {
+            error("TODO: implement descriptors");
+        }
+    }
+
+    vkCmdPushConstants(cmdbuf, layout, VK_SHADER_STAGE_COMPUTE_BIT, 0, (uint32_t) push_constant_size, push_constants.data());
+    // launch_params.grid is given in threads; Vulkan expects workgroup counts
+    vkCmdDispatch(cmdbuf, launch_params.grid[0] / launch_params.block[0], launch_params.grid[1] / launch_params.block[1], launch_params.grid[2] / launch_params.block[2]);
+}
+
+void VulkanPlatform::launch_kernel(DeviceId dev, const LaunchParams& launch_params) {
+    auto& device = usable_devices[dev];
+    auto kernel = device->load_kernel(launch_params.file_name, launch_params.kernel_name);
+
+    device->execute_command_buffer_oneshot([&](VkCommandBuffer cmd_buf) {
+        kernel->setup(cmd_buf, launch_params);
+    });
+}
+
+void VulkanPlatform::synchronize(DeviceId dev) {
+    // Nothing to do for now: execute_command_buffer_oneshot() waits for the device
+    // to go idle after every submission. TODO: don't wait for idle everywhere
+}
+
+VkCommandBuffer VulkanPlatform::Device::obtain_command_buffer() {
+    if (!spare_cmd_bufs.empty()) {
+        VkCommandBuffer cmd_buf = spare_cmd_bufs.back();
+        spare_cmd_bufs.pop_back();
+        return cmd_buf;
+    }
+    auto cmd_buf_create_info = VkCommandBufferAllocateInfo {
+        .sType = VK_STRUCTURE_TYPE_COMMAND_BUFFER_ALLOCATE_INFO,
+        .pNext = nullptr,
+        .commandPool = cmd_pool,
+        .level = VK_COMMAND_BUFFER_LEVEL_PRIMARY,
+        .commandBufferCount = 1
+    };
+    VkCommandBuffer cmd_buf;
+    CHECK(vkAllocateCommandBuffers(handle_, &cmd_buf_create_info, &cmd_buf));
+    return cmd_buf;
+}
+
+void VulkanPlatform::Device::return_command_buffer(VkCommandBuffer cmd_buf) {
+    vkResetCommandBuffer(cmd_buf, 0);
+    spare_cmd_bufs.push_back(cmd_buf);
+}
+
+void VulkanPlatform::Device::execute_command_buffer_oneshot(std::function<void(VkCommandBuffer)> fn) {
+    VkCommandBuffer cmd_buf = obtain_command_buffer();
+    auto begin_command_buffer_info = VkCommandBufferBeginInfo {
+        .sType = VK_STRUCTURE_TYPE_COMMAND_BUFFER_BEGIN_INFO,
+        .pNext = nullptr,
+        .flags = VK_COMMAND_BUFFER_USAGE_ONE_TIME_SUBMIT_BIT,
+        .pInheritanceInfo = nullptr,
+    };
+    CHECK(vkBeginCommandBuffer(cmd_buf, &begin_command_buffer_info));
+    fn(cmd_buf);
+    CHECK(vkEndCommandBuffer(cmd_buf));
+    auto submit_info = VkSubmitInfo {
+        .sType = VK_STRUCTURE_TYPE_SUBMIT_INFO,
+        .pNext = nullptr,
+        .waitSemaphoreCount = 0,
+        .pWaitSemaphores = nullptr,
+        .pWaitDstStageMask = nullptr,
+        .commandBufferCount = 1,
+        .pCommandBuffers = &cmd_buf,
+        .signalSemaphoreCount = 0,
+        .pSignalSemaphores = nullptr,
+    };
+    CHECK(vkQueueSubmit(queue, 1, &submit_info, VK_NULL_HANDLE));
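+    // Submissions are fully synchronous for now: block until the device is idle
+    // before recycling the command buffer, so no fences are needed on this path.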
+    CHECK(vkDeviceWaitIdle(handle_));
+    return_command_buffer(cmd_buf);
+}
+
+void VulkanPlatform::copy(DeviceId dev_src, const void* src, int64_t offset_src, DeviceId dev_dst, void* dst, int64_t offset_dst, int64_t size) {
+    command_unavailable("copy");
+}
+
+void VulkanPlatform::copy_from_host(const void* src, int64_t offset_src, DeviceId dev_dst, void* dst, int64_t offset_dst, int64_t size) {
+    auto& device = usable_devices[dev_dst];
+    auto dst_buffer = device->get_buffer_by_device_address(reinterpret_cast<VkDeviceAddress>(dst));
+
+    std::unique_ptr<Buffer> tmp_buffer;
+
+    void* host_ptr = (void*)((size_t) src + offset_src);
+    // Import the host memory and wrap it in a buffer if possible; otherwise stage
+    // through a host-visible buffer and copy into it manually
+    if (device->can_import_host_memory) {
+        tmp_buffer = std::make_unique<Buffer>(*device, size, Buffer::ImportedHostMemory { host_ptr }, VK_BUFFER_USAGE_TRANSFER_SRC_BIT);
+    } else {
+        tmp_buffer = std::make_unique<Buffer>(*device, size, Buffer::HostMemory {}, VK_BUFFER_USAGE_TRANSFER_SRC_BIT);
+        void* mapped = nullptr;
+        CHECK(vkMapMemory(device->handle_, tmp_buffer->device_memory_, 0, size, 0, &mapped));
+        assert(mapped != nullptr);
+        memcpy(mapped, host_ptr, size);
+        vkUnmapMemory(device->handle_, tmp_buffer->device_memory_);
+    }
+
+    device->execute_command_buffer_oneshot([&](VkCommandBuffer cmd_buf) {
+        VkBufferCopy copy_region {
+            .srcOffset = 0,
+            .dstOffset = (VkDeviceSize) offset_dst,
+            .size = (VkDeviceSize) size,
+        };
+        vkCmdCopyBuffer(cmd_buf, tmp_buffer->handle_, dst_buffer->handle_, 1, &copy_region);
+    });
+}
+
+void VulkanPlatform::copy_to_host(DeviceId dev_src, const void* src, int64_t offset_src, void* dst, int64_t offset_dst, int64_t size) {
+    auto& device = usable_devices[dev_src];
+    auto src_buffer = device->get_buffer_by_device_address(reinterpret_cast<VkDeviceAddress>(src));
+
+    std::unique_ptr<Buffer> tmp_buffer;
+
+    void* host_ptr = (void*)((size_t) dst + offset_dst);
+    // Import the host memory and wrap it in a buffer if possible; otherwise stage
+    // through a host-visible buffer and copy out of it afterwards
+    if (device->can_import_host_memory) {
+        tmp_buffer = std::make_unique<Buffer>(*device, size, Buffer::ImportedHostMemory { host_ptr }, VK_BUFFER_USAGE_TRANSFER_DST_BIT);
+    } else {
+        tmp_buffer = std::make_unique<Buffer>(*device, size, Buffer::HostMemory {}, VK_BUFFER_USAGE_TRANSFER_DST_BIT);
+    }
+
+    device->execute_command_buffer_oneshot([&](VkCommandBuffer cmd_buf) {
+        VkBufferCopy copy_region {
+            .srcOffset = (VkDeviceSize) offset_src,
+            .dstOffset = 0,
+            .size = (VkDeviceSize) size,
+        };
+        vkCmdCopyBuffer(cmd_buf, src_buffer->handle_, tmp_buffer->handle_, 1, &copy_region);
+    });
+
+    if (!device->can_import_host_memory) {
+        void* mapped = nullptr;
+        CHECK(vkMapMemory(device->handle_, tmp_buffer->device_memory_, 0, size, 0, &mapped));
+        assert(mapped != nullptr);
+        memcpy(host_ptr, mapped, size);
+        vkUnmapMemory(device->handle_, tmp_buffer->device_memory_);
+    }
+}
+
+const char* VulkanPlatform::device_name(DeviceId dev) const {
+    return usable_devices[dev]->properties.properties.deviceName;
+}
+
+void register_vulkan_platform(Runtime* runtime) {
+    runtime->register_platform<VulkanPlatform>();
+}
+
+VulkanPlatform::Kernel::~Kernel() {
+    vkDestroyPipeline(device_.handle_, pipeline, nullptr);
+    vkDestroyPipelineLayout(device_.handle_, layout, nullptr);
+    vkDestroyShaderModule(device_.handle_, shader_module, nullptr);
+}
diff --git a/src/vulkan_platform.h b/src/vulkan_platform.h
new file mode 100644
index 00000000..38f3f615
--- /dev/null
+++ b/src/vulkan_platform.h
@@ -0,0 +1,168 @@
+#ifndef ANYDSL_RUNTIME_VULKAN_PLATFORM_H
+#define ANYDSL_RUNTIME_VULKAN_PLATFORM_H
+
+#include "platform.h"
+#include <vulkan/vulkan.h>
+
+namespace shady {
+extern "C" {
+#include "shady/runtime/vulkan.h"
+}
+}
+
+#include <functional>
+#include <memory>
+#include <unordered_map>
+#include <variant>
+
+/// Vulkan requires certain function pointers to be loaded manually; this macro automates the boilerplate
+#define DevicesExtensionsFunctions(f) \
+    f(vkGetMemoryHostPointerPropertiesEXT) \
+    f(vkGetBufferDeviceAddressKHR)
+
+class VulkanPlatform : public Platform {
+public:
+    VulkanPlatform(Runtime* runtime);
+    ~VulkanPlatform() override;
+
+protected:
+    void* alloc(DeviceId dev, int64_t size) override;
+    void* alloc_host(DeviceId dev, int64_t size) override;
+    void* alloc_unified(DeviceId dev, int64_t size) override { command_unavailable("alloc_unified"); }
+    void* get_device_ptr(DeviceId dev, void* ptr) override;
+    void release(DeviceId dev, void* ptr) override;
+    void release_host(DeviceId dev, void* ptr) override;
+
+    void launch_kernel(DeviceId dev, const LaunchParams& launch_params) override;
+    void synchronize(DeviceId dev) override;
+
+    void copy(DeviceId dev_src, const void* src, int64_t offset_src, DeviceId dev_dst, void* dst, int64_t offset_dst, int64_t size) override;
+    void copy_from_host(const void* src, int64_t offset_src, DeviceId dev_dst, void* dst, int64_t offset_dst, int64_t size) override;
+    void copy_to_host(DeviceId dev_src, const void* src, int64_t offset_src, void* dst, int64_t offset_dst, int64_t size) override;
+
+    size_t dev_count() const override { return usable_devices.size(); }
+    std::string name() const override { return "Vulkan"; }
+
+    const char* device_name(DeviceId dev) const override;
+    bool device_check_feature_support(DeviceId, const char*) const override { return false; }
+
+    struct Device;
+
+    struct Resource {
+        Device& device_;
+
+        Resource(Device& device) : device_(device) {}
+        virtual ~Resource() {}
+    };
+
+    struct Buffer : public Resource {
+        VkBuffer handle_;
+
+        void* host_address_ = nullptr;
+        VkDeviceAddress device_address_ = 0;
+
+        VkDeviceMemory device_memory_;
+
+        const static VkBufferUsageFlags2 ALL_BUFFER_USAGE =
+            VK_BUFFER_USAGE_2_STORAGE_BUFFER_BIT |
+            VK_BUFFER_USAGE_2_SHADER_DEVICE_ADDRESS_BIT |
+            VK_BUFFER_USAGE_2_TRANSFER_SRC_BIT |
+            VK_BUFFER_USAGE_2_TRANSFER_DST_BIT;
+
+        struct ImportedHostMemory {
+            void* host_memory_;
+        };
+
+        struct DeviceMemory {};
+        struct HostMemory {};
+        struct UnifiedMemory {};
+
+        using BackingStorage = std::variant<ImportedHostMemory, DeviceMemory, HostMemory, UnifiedMemory>;
+        friend Device;
+        friend Platform;
+
+        Buffer(Device& device, size_t size, BackingStorage backing_storage, VkBufferUsageFlags2 usages = ALL_BUFFER_USAGE);
+        ~Buffer() override;
+    };
+
+    struct Kernel {
+        Device& device_;
+
+        shady::Module* shady_module_;
+        std::vector<shady::RuntimeInterfaceItem> interface;
+        size_t push_constant_size = 0;
+
+        VkShaderModule shader_module;
+        VkPipelineLayout layout;
+        VkPipeline pipeline;
+
+        Kernel(Device& device, std::string, std::string);
+        void setup(VkCommandBuffer, const LaunchParams& launch_params);
+        ~Kernel();
+    };
+
+    struct ExtensionFns {
+#define f(s) PFN_##s s;
+        DevicesExtensionsFunctions(f)
+#undef f
+    };
+
+    struct Device {
+        VulkanPlatform& platform_;
+        VkPhysicalDevice physical_device;
+        VkDevice handle_ = nullptr;
+        size_t device_id;
+
+        ExtensionFns extension_fns;
+
+        VkPhysicalDeviceProperties2 properties = {
+            .sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_PROPERTIES_2,
+        };
+
+        bool can_import_host_memory = false;
+        VkPhysicalDeviceExternalMemoryHostPropertiesEXT external_memory_host_properties {
+            .sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_EXTERNAL_MEMORY_HOST_PROPERTIES_EXT,
+            .pNext = nullptr,
+            .minImportedHostPointerAlignment = 0xFFFFFFFF,
+        };
+
+        shady::ShadyVkrPhysicalDeviceCaps shady_caps_;
+        shady::TargetConfig target_config_;
+
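+        // Live buffers, keyed by the VkDeviceAddress that alloc() hands out as this
+        // platform's opaque "pointer" value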
+        std::unordered_map<VkDeviceAddress, std::unique_ptr<Buffer>> buffers_;
+        std::unordered_map<std::string, std::unique_ptr<Kernel>> kernels;
+
+        VkQueue queue;
+        VkCommandPool cmd_pool;
+        std::vector<VkCommandBuffer> spare_cmd_bufs;
+
+        Device(VulkanPlatform& platform, VkPhysicalDevice physical_device, size_t device_id);
+        ~Device();
+
+        uint32_t find_suitable_memory_type(uint32_t memory_type_bits, VkMemoryPropertyFlags, VkMemoryHeapFlags = 0);
+        VkDeviceMemory allocate_memory(VkDeviceSize, uint32_t memory_type_bits, VkMemoryPropertyFlags memory_flags, VkMemoryHeapFlags heap_flags = 0);
+        std::pair<VkDeviceMemory, size_t> import_host_memory(void* ptr, size_t size);
+
+        Buffer* get_buffer_by_device_address(VkDeviceAddress addr) {
+            auto found = buffers_.find(addr);
+            if (found != buffers_.end())
+                return &*found->second;
+            return nullptr;
+        }
+
+        uint64_t create_buffer_resource(size_t, Buffer::BackingStorage backing, VkBufferUsageFlags usage_flags);
+
+        VkCommandBuffer obtain_command_buffer();
+        void return_command_buffer(VkCommandBuffer cmd_buf);
+        void execute_command_buffer_oneshot(std::function<void(VkCommandBuffer)> fn);
+
+        Kernel* load_kernel(const std::string&, const std::string&);
+    };
+
+    VkInstance instance;
+    std::vector<VkPhysicalDevice> physical_devices;
+    std::vector<std::unique_ptr<Device>> usable_devices;
+
+    shady::CompilerConfig compiler_config_ = shady::shd_default_compiler_config();
+};
+
+#endif