Skip to content

Commit f5de3bc

Browse files
committed
Use 4-byte LPWStrs on non-Windows platforms
1 parent 0977676 commit f5de3bc

File tree

3 files changed

+146
-28
lines changed

3 files changed

+146
-28
lines changed

src/Core/Silk.NET.Core.Tests/TestSilkMarshal.cs

Lines changed: 29 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ public class TestSilkMarshal
1818
};
1919

2020
[Fact]
21-
public unsafe void TestEncodingLPWStr()
21+
public unsafe void TestEncodingToLPWStr()
2222
{
2323
var input = "Hello world";
2424

@@ -30,7 +30,7 @@ public unsafe void TestEncodingLPWStr()
3030
Assert.Equal(input.Length, (int)SilkMarshal.StringLength(pointer, NativeStringEncoding.LPWStr));
3131

3232
// Use short for comparison
33-
Assert.Equal(new short[] { 0x48, 0x65, 0x6c, 0x6c, 0x6f, 0x20, 0x77, 0x6f, 0x72, 0x6c, 0x64 }, new Span<short>((void*)pointer, input.Length));
33+
Assert.Equal(new short[] { 0x48, 0x65, 0x6c, 0x6c, 0x6f, 0x20, 0x77, 0x6f, 0x72, 0x6c, 0x64, 0x00 }, new Span<short>((void*)pointer, input.Length + 1));
3434
}
3535
else
3636
{
@@ -39,7 +39,33 @@ public unsafe void TestEncodingLPWStr()
3939
Assert.Equal(input.Length, (int)SilkMarshal.StringLength(pointer, NativeStringEncoding.LPWStr));
4040

4141
// Use int for comparison
42-
Assert.Equal(new int[] { 0x48, 0x65, 0x6c, 0x6c, 0x6f, 0x20, 0x77, 0x6f, 0x72, 0x6c, 0x64 }, new Span<int>((void*)pointer, input.Length));
42+
Assert.Equal(new int[] { 0x48, 0x65, 0x6c, 0x6c, 0x6f, 0x20, 0x77, 0x6f, 0x72, 0x6c, 0x64, 0x00 }, new Span<int>((void*)pointer, input.Length + 1));
43+
}
44+
}
45+
46+
[Fact]
47+
public unsafe void TestEncodingFromLPWStr()
48+
{
49+
var expected = "Hello world";
50+
51+
// LPWStr is 2 bytes on Windows, 4 bytes elsewhere (usually)
52+
if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows))
53+
{
54+
var characters = new short[] { 0x48, 0x65, 0x6c, 0x6c, 0x6f, 0x20, 0x77, 0x6f, 0x72, 0x6c, 0x64 };
55+
fixed (short* pCharacters = characters)
56+
{
57+
var output = SilkMarshal.PtrToString((nint)pCharacters, NativeStringEncoding.LPWStr);
58+
Assert.Equal(expected, output);
59+
}
60+
}
61+
else
62+
{
63+
var characters = new int[] { 0x48, 0x65, 0x6c, 0x6c, 0x6f, 0x20, 0x77, 0x6f, 0x72, 0x6c, 0x64 };
64+
fixed (int* pCharacters = characters)
65+
{
66+
var output = SilkMarshal.PtrToString((nint)pCharacters, NativeStringEncoding.LPWStr);
67+
Assert.Equal(expected, output);
68+
}
4369
}
4470
}
4571

src/Core/Silk.NET.Core/Native/NativeStringEncoding.cs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,9 @@ public enum NativeStringEncoding
99
LPStr = UnmanagedType.LPStr,
1010
LPTStr = UnmanagedType.LPTStr,
1111
LPUTF8Str = UnmanagedType.LPUTF8Str,
12+
/// <summary>
13+
/// On Windows, a 2-byte, null-terminated Unicode character string. On other platforms, each character will be 4 bytes instead.
14+
/// </summary>
1215
LPWStr = UnmanagedType.LPWStr,
1316
WinString = UnmanagedType.WinString,
1417
Ansi = LPStr,

src/Core/Silk.NET.Core/Native/SilkMarshal.cs

Lines changed: 114 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,8 @@ public static int GetMaxSizeOf(string? input, NativeStringEncoding encoding = Na
144144
NativeStringEncoding.BStr => -1,
145145
NativeStringEncoding.LPStr or NativeStringEncoding.LPTStr or NativeStringEncoding.LPUTF8Str
146146
=> (input is null ? 0 : Encoding.UTF8.GetMaxByteCount(input.Length)) + 1,
147-
NativeStringEncoding.LPWStr => ((input?.Length ?? 0) + 1) * 2,
147+
NativeStringEncoding.LPWStr when RuntimeInformation.IsOSPlatform(OSPlatform.Windows) => ((input?.Length ?? 0) + 1) * 2,
148+
NativeStringEncoding.LPWStr when !RuntimeInformation.IsOSPlatform(OSPlatform.Windows) => ((input?.Length ?? 0) + 1) * 4,
148149
_ => -1
149150
};
150151

@@ -198,19 +199,35 @@ public static unsafe int StringIntoSpan
198199
span[convertedBytes] = 0;
199200
return ++convertedBytes;
200201
}
201-
case NativeStringEncoding.LPWStr:
202+
case NativeStringEncoding.LPWStr when RuntimeInformation.IsOSPlatform(OSPlatform.Windows):
202203
{
203204
fixed (char* firstChar = input)
205+
fixed (byte* bytes = span)
204206
{
205-
fixed (byte* bytes = span)
206-
{
207-
Buffer.MemoryCopy(firstChar, bytes, span.Length, input.Length * 2);
208-
((char*)bytes)[input.Length] = default;
209-
}
207+
Buffer.MemoryCopy(firstChar, bytes, span.Length, input.Length * 2);
208+
((char*)bytes)[input.Length] = default;
210209
}
211210

212211
return input.Length + 1;
213212
}
213+
case NativeStringEncoding.LPWStr when !RuntimeInformation.IsOSPlatform(OSPlatform.Windows):
214+
{
215+
fixed (char* firstChar = input)
216+
fixed (byte* bytes = span)
217+
{
218+
var maxLength = span.Length / 2;
219+
var i = 0;
220+
while (firstChar[i] != 0 && i < maxLength - 1)
221+
{
222+
((uint*)bytes)[i] = firstChar[i];
223+
i++;
224+
}
225+
226+
((uint*)bytes)[i] = default;
227+
228+
return i * 4;
229+
}
230+
}
214231
default:
215232
{
216233
ThrowInvalidEncoding<GlobalMemory>();
@@ -238,7 +255,7 @@ public static nint AllocateString(int length, NativeStringEncoding encoding = Na
238255
NativeStringEncoding.LPWStr => Allocate(length),
239256
_ => ThrowInvalidEncoding<nint>()
240257
};
241-
258+
242259
/// <summary>
243260
/// Free a string pointer
244261
/// </summary>
@@ -311,7 +328,28 @@ static unsafe string BStrToString(nint ptr)
311328
=> new string((char*) ptr, 0, (int) (*((uint*) ptr - 1) / sizeof(char)));
312329

313330
static unsafe string AnsiToString(nint ptr) => new string((sbyte*) ptr);
314-
static unsafe string WideToString(nint ptr) => new string((char*) ptr);
331+
332+
static unsafe string WideToString(nint ptr)
333+
{
334+
if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows))
335+
{
336+
return new string((char*) ptr);
337+
}
338+
else
339+
{
340+
var length = StringLength(ptr, NativeStringEncoding.LPWStr);
341+
var characters = new ushort[length];
342+
for (var i = 0; i < (uint)length; i++)
343+
{
344+
characters[i] = (ushort)((uint*)ptr)[i];
345+
}
346+
347+
fixed (ushort* pCharacters = characters)
348+
{
349+
return new string((char*)pCharacters);
350+
}
351+
}
352+
};
315353
}
316354

317355
/// <summary>
@@ -524,15 +562,41 @@ Func<nint, string> customUnmarshaller
524562
/// </remarks>
525563
#if NET6_0_OR_GREATER
526564
[MethodImpl(MethodImplOptions.AggressiveInlining)]
527-
public static unsafe nuint StringLength(
565+
public static unsafe nuint StringLength
566+
(
528567
nint ptr,
529568
NativeStringEncoding encoding = NativeStringEncoding.Ansi
530-
) =>
531-
(nuint)(
532-
encoding == NativeStringEncoding.LPWStr
533-
? MemoryMarshal.CreateReadOnlySpanFromNullTerminated((char*)ptr).Length
534-
: MemoryMarshal.CreateReadOnlySpanFromNullTerminated((byte*)ptr).Length
535-
);
569+
)
570+
{
571+
switch (encoding)
572+
{
573+
default:
574+
{
575+
return (nuint)MemoryMarshal.CreateReadOnlySpanFromNullTerminated((byte*)ptr).Length;
576+
}
577+
case NativeStringEncoding.LPWStr when RuntimeInformation.IsOSPlatform(OSPlatform.Windows):
578+
{
579+
return (nuint)MemoryMarshal.CreateReadOnlySpanFromNullTerminated((char*)ptr).Length;
580+
}
581+
case NativeStringEncoding.LPWStr when !RuntimeInformation.IsOSPlatform(OSPlatform.Windows):
582+
{
583+
// No int overload for CreateReadOnlySpanFromNullTerminated
584+
if (ptr == 0)
585+
{
586+
return 0;
587+
}
588+
589+
nuint length = 0;
590+
while (((uint*) ptr)![length] != 0)
591+
{
592+
length++;
593+
}
594+
595+
return length;
596+
}
597+
}
598+
}
599+
536600
#else
537601
public static unsafe nuint StringLength(
538602
nint ptr,
@@ -543,15 +607,40 @@ public static unsafe nuint StringLength(
543607
{
544608
return 0;
545609
}
546-
nuint ret;
547-
for (
548-
ret = 0;
549-
encoding == NativeStringEncoding.LPWStr
550-
? ((char*)ptr)![ret] != 0
551-
: ((byte*)ptr)![ret] != 0;
552-
ret++
553-
) { }
554-
return ret;
610+
611+
nuint length = 0;
612+
switch (encoding)
613+
{
614+
default:
615+
{
616+
while (((byte*) ptr)![length] != 0)
617+
{
618+
length++;
619+
}
620+
621+
break;
622+
}
623+
case NativeStringEncoding.LPWStr when RuntimeInformation.IsOSPlatform(OSPlatform.Windows):
624+
{
625+
while (((char*) ptr)![length] != 0)
626+
{
627+
length++;
628+
}
629+
630+
break;
631+
}
632+
case NativeStringEncoding.LPWStr when !RuntimeInformation.IsOSPlatform(OSPlatform.Windows):
633+
{
634+
while (((uint*) ptr)![length] != 0)
635+
{
636+
length++;
637+
}
638+
639+
break;
640+
}
641+
}
642+
643+
return length;
555644
}
556645
#endif
557646

0 commit comments

Comments
 (0)