|
10 | 10 | using System.Collections.Generic; |
11 | 11 | using System.Diagnostics.CodeAnalysis; |
12 | 12 | using System.Numerics; |
| 13 | +using System.Runtime.InteropServices; |
13 | 14 |
|
14 | 15 | namespace IronPython.Runtime { |
15 | 16 |
|
@@ -376,79 +377,77 @@ IEnumerator IEnumerable.GetEnumerator() { |
376 | 377 | #endregion |
377 | 378 | } |
378 | 379 |
|
379 | | - public sealed class MemoryBufferWrapper : IPythonBuffer { |
380 | | - private readonly ReadOnlyMemory<byte> _rom; |
381 | | - private readonly Memory<byte>? _memory; |
382 | | - private readonly BufferFlags _flags; |
| 380 | + public sealed class MemoryBufferProtocolWrapper<T> : IBufferProtocol where T : unmanaged { |
| 381 | + private readonly ReadOnlyMemory<T> _rom; |
| 382 | + private readonly Memory<T>? _memory; |
| 383 | + private readonly string? _format; |
383 | 384 |
|
384 | | - public MemoryBufferWrapper(ReadOnlyMemory<byte> memory, BufferFlags flags) { |
| 385 | + public MemoryBufferProtocolWrapper(ReadOnlyMemory<T> memory, string? format = null) { |
385 | 386 | _rom = memory; |
386 | 387 | _memory = null; |
387 | | - _flags = flags; |
| 388 | + _format = format; |
388 | 389 | } |
389 | 390 |
|
390 | | - public MemoryBufferWrapper(Memory<byte> memory, BufferFlags flags) { |
| 391 | + public MemoryBufferProtocolWrapper(Memory<T> memory, string? format = null) { |
391 | 392 | _rom = memory; |
392 | 393 | _memory = memory; |
393 | | - _flags = flags; |
| 394 | + _format = format; |
394 | 395 | } |
395 | 396 |
|
396 | | - public void Dispose() { } |
397 | | - |
398 | | - public object Object => _memory ?? _rom; |
| 397 | + public IPythonBuffer? GetBuffer(BufferFlags flags, bool throwOnError) { |
| 398 | + if (_memory.HasValue) { |
| 399 | + return new MemoryBufferWrapper(this, flags); |
| 400 | + } |
399 | 401 |
|
400 | | - public bool IsReadOnly => !_memory.HasValue; |
| 402 | + if (flags.HasFlag(BufferFlags.Writable)) { |
| 403 | + if (throwOnError) { |
| 404 | + throw Operations.PythonOps.BufferError("ReadOnlyMemory is not writable."); |
| 405 | + } |
| 406 | + return null; |
| 407 | + } |
401 | 408 |
|
402 | | - public ReadOnlySpan<byte> AsReadOnlySpan() => _rom.Span; |
| 409 | + return new MemoryBufferWrapper(this, flags); |
| 410 | + } |
403 | 411 |
|
404 | | - public Span<byte> AsSpan() => _memory.HasValue ? _memory.Value.Span : throw new InvalidOperationException("ReadOnlyMemory is not writable"); |
| 412 | + private sealed unsafe class MemoryBufferWrapper : IPythonBuffer { |
| 413 | + private readonly MemoryBufferProtocolWrapper<T> _wrapper; |
| 414 | + private readonly BufferFlags _flags; |
405 | 415 |
|
406 | | - public MemoryHandle Pin() => _rom.Pin(); |
| 416 | + public MemoryBufferWrapper(MemoryBufferProtocolWrapper<T> wrapper, BufferFlags flags) { |
| 417 | + _wrapper = wrapper; |
| 418 | + _flags = flags; |
| 419 | + } |
407 | 420 |
|
408 | | - public int Offset => 0; |
| 421 | + public void Dispose() { } |
409 | 422 |
|
410 | | - public string? Format => _flags.HasFlag(BufferFlags.Format) ? "B" : null; |
| 423 | + public object Object => _wrapper._memory ?? _wrapper._rom; |
411 | 424 |
|
412 | | - public int ItemCount => _rom.Length; |
| 425 | + public bool IsReadOnly => !_wrapper._memory.HasValue; |
413 | 426 |
|
414 | | - public int ItemSize => 1; |
| 427 | + public ReadOnlySpan<byte> AsReadOnlySpan() => MemoryMarshal.Cast<T, byte>(_wrapper._rom.Span); |
415 | 428 |
|
416 | | - public int NumOfDims => 1; |
| 429 | + public Span<byte> AsSpan() |
| 430 | + => _wrapper._memory.HasValue |
| 431 | + ? MemoryMarshal.Cast<T, byte>(_wrapper._memory.Value.Span) |
| 432 | + : throw new InvalidOperationException("ReadOnlyMemory is not writable"); |
417 | 433 |
|
418 | | - public IReadOnlyList<int>? Shape => null; |
| 434 | + public MemoryHandle Pin() => _wrapper._rom.Pin(); |
419 | 435 |
|
420 | | - public IReadOnlyList<int>? Strides => null; |
| 436 | + public int Offset => 0; |
421 | 437 |
|
422 | | - public IReadOnlyList<int>? SubOffsets => null; |
423 | | - } |
| 438 | + public string? Format => _flags.HasFlag(BufferFlags.Format) ? _wrapper._format : null; |
424 | 439 |
|
425 | | - public class MemoryBufferProtocolWrapper : IBufferProtocol { |
426 | | - private readonly ReadOnlyMemory<byte> _rom; |
427 | | - private readonly Memory<byte>? _memory; |
| 440 | + public int ItemCount => _wrapper._rom.Length; |
428 | 441 |
|
429 | | - public MemoryBufferProtocolWrapper(ReadOnlyMemory<byte> memory) { |
430 | | - _rom = memory; |
431 | | - _memory = null; |
432 | | - } |
| 442 | + public int ItemSize => sizeof(T); |
433 | 443 |
|
434 | | - public MemoryBufferProtocolWrapper(Memory<byte> memory) { |
435 | | - _rom = memory; |
436 | | - _memory = memory; |
437 | | - } |
| 444 | + public int NumOfDims => 1; |
438 | 445 |
|
439 | | - public IPythonBuffer? GetBuffer(BufferFlags flags, bool throwOnError) { |
440 | | - if (_memory.HasValue) { |
441 | | - return new MemoryBufferWrapper(_memory.Value, flags); |
442 | | - } |
| 446 | + public IReadOnlyList<int>? Shape => null; |
443 | 447 |
|
444 | | - if (flags.HasFlag(BufferFlags.Writable)) { |
445 | | - if (throwOnError) { |
446 | | - throw Operations.PythonOps.BufferError("ReadOnlyMemory is not writable."); |
447 | | - } |
448 | | - return null; |
449 | | - } |
| 448 | + public IReadOnlyList<int>? Strides => null; |
450 | 449 |
|
451 | | - return new MemoryBufferWrapper(_rom, flags); |
| 450 | + public IReadOnlyList<int>? SubOffsets => null; |
452 | 451 | } |
453 | 452 | } |
454 | 453 | } |
0 commit comments