|
10 | 10 | #import <MetalPerformanceShaders/MetalPerformanceShaders.h> |
11 | 11 | #import <MetalPerformanceShadersGraph/MetalPerformanceShadersGraph.h> |
12 | 12 | #import <Foundation/Foundation.h> |
| 13 | +#include <simd/simd.h> |
13 | 14 | #include <executorch/runtime/platform/log.h> |
14 | 15 | #include <executorch/runtime/core/exec_aten/exec_aten.h> |
15 | 16 | #include <executorch/backends/apple/metal/runtime/shims/et_metal.h> |
@@ -377,6 +378,58 @@ int metal_copy_memory(void* dst, const void* src, size_t nbytes, bool src_is_dev |
377 | 378 | ET_LOG(Debug, "ETMetalKernelFunction::setArg: Set int64_t value %lld at index %u", val, idx); |
378 | 379 | } |
379 | 380 |
|
// Binds a 32-bit unsigned scalar to argument slot `idx` of the active compute
// encoder by copying its bytes inline. Logs an error and does nothing when no
// encoder is active.
void ETMetalKernelFunction::setArg(unsigned idx, uint32_t val) {
  if (encoder_) {
    // setBytes copies from the pointer, so a stack-local copy is sufficient.
    const uint32_t payload = val;
    [encoder_ setBytes:&payload length:sizeof(payload) atIndex:idx];
    ET_LOG(Debug, "ETMetalKernelFunction::setArg: Set uint32_t value %u at index %u", val, idx);
    return;
  }

  ET_LOG(Error, "ETMetalKernelFunction::setArg: No active encoder");
}
| 390 | + |
// Binds a single-precision float to argument slot `idx` of the active compute
// encoder by copying its bytes inline. Logs an error and does nothing when no
// encoder is active.
void ETMetalKernelFunction::setArg(unsigned idx, float val) {
  if (encoder_) {
    // setBytes copies from the pointer, so a stack-local copy is sufficient.
    const float payload = val;
    [encoder_ setBytes:&payload length:sizeof(payload) atIndex:idx];
    ET_LOG(Debug, "ETMetalKernelFunction::setArg: Set float value %f at index %u", val, idx);
    return;
  }

  ET_LOG(Error, "ETMetalKernelFunction::setArg: No active encoder");
}
| 400 | + |
// Binds a boolean flag to argument slot `idx` of the active compute encoder,
// passing sizeof(bool) bytes inline. Logs an error and does nothing when no
// encoder is active.
void ETMetalKernelFunction::setArg(unsigned idx, bool val) {
  if (encoder_) {
    // setBytes copies from the pointer, so a stack-local copy is sufficient.
    const bool payload = val;
    [encoder_ setBytes:&payload length:sizeof(payload) atIndex:idx];

    const char* rendered = val ? "true" : "false";
    ET_LOG(Debug, "ETMetalKernelFunction::setArg: Set bool value %s at index %u", rendered, idx);
    return;
  }

  ET_LOG(Error, "ETMetalKernelFunction::setArg: No active encoder");
}
| 410 | + |
// Binds an arbitrary blob of `size` bytes starting at `data` to argument slot
// `idx` of the active compute encoder. The bytes are handed to
// -setBytes:length:atIndex:, so the caller's buffer only needs to stay valid
// for the duration of this call.
//
// Logs an error and does nothing when:
//   - no encoder is active, or
//   - `data` is null or `size` is zero (there is nothing valid to bind, and
//     passing such arguments to the encoder is an API misuse).
//
// NOTE(review): Apple's documentation recommends setBytes only for small,
// single-use data (under 4 KB); larger payloads should go through an
// MTLBuffer. This function does not enforce that limit - confirm callers
// respect it.
void ETMetalKernelFunction::setArg(unsigned idx, const void* data, size_t size) {
  if (!encoder_) {
    ET_LOG(Error, "ETMetalKernelFunction::setArg: No active encoder");
    return;
  }

  // Reject empty payloads up front instead of forwarding an invalid
  // pointer/length pair to the encoder.
  if (data == nullptr || size == 0) {
    ET_LOG(
        Error,
        "ETMetalKernelFunction::setArg: Invalid bytes argument at index %u (data: %p, size: %zu)",
        idx,
        data,
        size);
    return;
  }

  [encoder_ setBytes:data length:size atIndex:idx];
  ET_LOG(Debug, "ETMetalKernelFunction::setArg: Set bytes at index %u (size: %zu)", idx, size);
}
| 420 | + |
// Binds the triple (x, y, z) to argument slot `idx` of the active compute
// encoder as a single uint3 value. Logs an error and does nothing when no
// encoder is active.
void ETMetalKernelFunction::setArgUint3(unsigned idx, uint32_t x, uint32_t y, uint32_t z) {
  if (encoder_) {
    // Pack into the SIMD library's uint3 so the bytes handed to the encoder
    // follow the layout the Metal shader's uint3 expects.
    simd_uint3 packed = {x, y, z};
    [encoder_ setBytes:&packed length:sizeof(packed) atIndex:idx];
    ET_LOG(Debug, "ETMetalKernelFunction::setArgUint3: Set uint3{%u, %u, %u} at index %u", x, y, z, idx);
    return;
  }

  ET_LOG(Error, "ETMetalKernelFunction::setArgUint3: No active encoder");
}
| 432 | + |
380 | 433 | void ETMetalKernelFunction::dispatchSingle(uint64_t length) { |
381 | 434 | if (!encoder_) { |
382 | 435 | ET_LOG(Error, "ETMetalKernelFunction::dispatchSingle: No active encoder"); |
@@ -502,6 +555,40 @@ int metal_copy_memory(void* dst, const void* src, size_t nbytes, bool src_is_dev |
502 | 555 |
|
503 | 556 | } |
504 | 557 |
|
// Encodes a compute dispatch of gridX x gridY x gridZ threadgroups, each of
// threadsX x threadsY x threadsZ threads, on the active encoder.
//
// Logs an error and encodes nothing when:
//   - no encoder is active or no compute pipeline state is set;
//   - any grid or threadgroup dimension is zero;
//   - the threads-per-threadgroup product exceeds the pipeline's
//     maxTotalThreadsPerThreadgroup.
void ETMetalKernelFunction::dispatchThreadgroups(uint64_t gridX, uint64_t gridY, uint64_t gridZ,
                                                 uint64_t threadsX, uint64_t threadsY, uint64_t threadsZ) {
  if (!encoder_) {
    ET_LOG(Error, "ETMetalKernelFunction::dispatchThreadgroups: No active encoder");
    return;
  }

  if (!cps_) {
    ET_LOG(Error, "ETMetalKernelFunction::dispatchThreadgroups: No compute pipeline state");
    return;
  }

  // A zero-sized axis describes an empty dispatch, which is not a valid
  // argument for dispatchThreadgroups; it would also divide by zero in the
  // limit check below.
  if (gridX == 0 || gridY == 0 || gridZ == 0 ||
      threadsX == 0 || threadsY == 0 || threadsZ == 0) {
    ET_LOG(Error, "ETMetalKernelFunction::dispatchThreadgroups: Grid [%llu, %llu, %llu] and threadgroup [%llu, %llu, %llu] dimensions must all be non-zero",
           (unsigned long long)gridX, (unsigned long long)gridY, (unsigned long long)gridZ,
           (unsigned long long)threadsX, (unsigned long long)threadsY, (unsigned long long)threadsZ);
    return;
  }

  const auto maxThreadsPerGroup = static_cast<uint64_t>([cps_ maxTotalThreadsPerThreadgroup]);

  // Validate threadsX * threadsY * threadsZ <= maxThreadsPerGroup using
  // division so the check cannot be defeated by uint64_t wrap-around (e.g.
  // 2^32 * 2^32 * 1 wraps to 0). Short-circuiting guarantees that
  // threadsX * threadsY is only evaluated once it is known to be
  // <= maxThreadsPerGroup, so that product cannot overflow either.
  const bool withinLimit =
      threadsX <= maxThreadsPerGroup &&
      threadsY <= maxThreadsPerGroup / threadsX &&
      threadsZ <= maxThreadsPerGroup / (threadsX * threadsY);

  if (!withinLimit) {
    // Report the per-axis sizes rather than the product: the product may not
    // be representable when the request is wildly out of range.
    ET_LOG(Error, "ETMetalKernelFunction::dispatchThreadgroups: Requested threadgroup [%llu, %llu, %llu] exceeds device maximum of %llu total threads per threadgroup",
           (unsigned long long)threadsX, (unsigned long long)threadsY, (unsigned long long)threadsZ,
           (unsigned long long)maxThreadsPerGroup);
    return;
  }

  MTLSize threadgroupsPerGrid = MTLSizeMake(gridX, gridY, gridZ);
  MTLSize threadsPerThreadgroup = MTLSizeMake(threadsX, threadsY, threadsZ);

  [encoder_ dispatchThreadgroups:threadgroupsPerGrid threadsPerThreadgroup:threadsPerThreadgroup];

  ET_LOG(Debug, "ETMetalKernelFunction::dispatchThreadgroups: Dispatched grid [%llu, %llu, %llu] with threadgroup [%llu, %llu, %llu]",
         (unsigned long long)gridX, (unsigned long long)gridY, (unsigned long long)gridZ,
         (unsigned long long)threadsX, (unsigned long long)threadsY, (unsigned long long)threadsZ);
}
| 591 | + |
505 | 592 | void ETMetalKernelFunction::runCommandBlock(std::function<void(void)> f) { |
506 | 593 | // Use dispatch_sync with the stream's serial queue for thread safety and synchronization |
507 | 594 | // This matches PyTorch's approach: dispatch_sync_with_rethrow(getCurrentMPSStream()->queue(), ...) |
|
0 commit comments