-
Notifications
You must be signed in to change notification settings - Fork 234
C# binding into STFT pre-processing #1843
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 7 commits
99687be
4715a9a
cdb6325
7f3a634
dcb2c14
835628a
fbe74ff
febc439
ad4de9b
595f96c
a2f4773
80b9430
9280016
bf29508
a12c530
7a9527d
cac5438
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,209 @@ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| // Copyright (c) Microsoft Corporation. All rights reserved. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| // Licensed under the MIT License. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| using System; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| using System.Collections.Generic; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| using System.Runtime.InteropServices; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| namespace Microsoft.ML.OnnxRuntimeGenAI | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| public static class SignalProcessor | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can these APIs be added to the |
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| // Numeric ElementType values matching OgaElementType in the C API: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| // enum OgaElementType { undefined=0, float32=1, ..., int64=7, ... } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| private const int ET_Float32 = 1; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| private const int ET_Int64 = 7; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| /// <summary> | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| /// Thin wrapper around the native OgaSplitSignalSegments. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| /// All arguments are OgaTensor handles (IntPtr). | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| /// </summary> | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| public static void SplitSignalSegments( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| IntPtr inputTensor, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| IntPtr srTensor, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| IntPtr frameMsTensor, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| IntPtr hopMsTensor, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| IntPtr energyThresholdDbTensor, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| IntPtr outputTensor) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| int err = NativeMethods.OgaSplitSignalSegments( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| inputTensor, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| srTensor, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| frameMsTensor, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| hopMsTensor, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| energyThresholdDbTensor, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| outputTensor); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if (err != 0) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| throw new InvalidOperationException($"OgaSplitSignalSegments failed with error code {err}"); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| /// <summary> | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| /// Thin wrapper around the native OgaMergeSignalSegments. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| /// All arguments are OgaTensor handles (IntPtr). | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| /// </summary> | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| public static void MergeSignalSegments( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| IntPtr segmentsTensor, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| IntPtr mergeGapMsTensor, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| IntPtr outputTensor) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| int err = NativeMethods.OgaMergeSignalSegments( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| segmentsTensor, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| mergeGapMsTensor, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| outputTensor); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if (err != 0) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| throw new InvalidOperationException($"OgaMergeSignalSegments failed with error code {err}"); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| /// <summary> | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| /// Create a tensor view over a managed float[] using OgaCreateTensorFromBuffer. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| /// </summary> | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| public static unsafe IntPtr CreateFloatTensorFromArray(float[] data, long[] shape) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if (data == null) throw new ArgumentNullException(nameof(data)); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if (shape == null) throw new ArgumentNullException(nameof(shape)); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| IntPtr tensor; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| fixed (float* p = data) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Result.VerifySuccess( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| NativeMethods.OgaCreateTensorFromBuffer( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| (IntPtr)p, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| shape, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| (UIntPtr)shape.Length, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| (ElementType)ET_Float32, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| out tensor)); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return tensor; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| /// <summary> | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| /// Create a tensor view over a managed long[] using OgaCreateTensorFromBuffer. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| /// </summary> | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| public static unsafe IntPtr CreateInt64TensorFromArray(long[] data, long[] shape) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
nenad1002 marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if (data == null) throw new ArgumentNullException(nameof(data)); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if (shape == null) throw new ArgumentNullException(nameof(shape)); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| IntPtr tensor; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| fixed (long* p = data) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Result.VerifySuccess( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| NativeMethods.OgaCreateTensorFromBuffer( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| (IntPtr)p, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| shape, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| (UIntPtr)shape.Length, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| (ElementType)ET_Int64, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| out tensor)); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return tensor; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| /// <summary> | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| /// Create an output tensor that points at a caller-owned long[] buffer. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| /// </summary> | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| public static unsafe IntPtr CreateOutputInt64Tensor(long[] backingBuffer, long rows, long cols) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if (backingBuffer == null) throw new ArgumentNullException(nameof(backingBuffer)); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if (rows * cols > backingBuffer.LongLength) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| throw new ArgumentException("backingBuffer too small for requested shape"); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| IntPtr tensor; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| long[] shape = new long[] { rows, cols }; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| fixed (long* p = backingBuffer) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Result.VerifySuccess( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| NativeMethods.OgaCreateTensorFromBuffer( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| (IntPtr)p, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| shape, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| (UIntPtr)shape.Length, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| (ElementType)ET_Int64, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| out tensor)); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return tensor; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| /// <summary> | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| /// Runs STFT over the input signal and finds the areas of high energy with start/end timestamps in ms. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| /// </summary> | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| public static (double Start, double End)[] SplitAndMergeSegments( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| float[] inputSignal, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| int sampleRate, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| int frameMs, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| int hopMs, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| float energyThresholdDb, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| int mergeGapMs) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if (inputSignal == null || inputSignal.Length == 0) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| throw new ArgumentException("Input array cannot be null or empty", nameof(inputSignal)); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| const int MaxSegs = 128; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| long[] splitBacking = new long[MaxSegs * 2]; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| long[] mergedBacking = new long[MaxSegs * 2]; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| IntPtr input = IntPtr.Zero, sr = IntPtr.Zero, frame = IntPtr.Zero, hop = IntPtr.Zero, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| thr = IntPtr.Zero, splitOut = IntPtr.Zero, mergeGap = IntPtr.Zero, mergedOut = IntPtr.Zero; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| try | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| long[] inputShape = new long[] { 1, inputSignal.Length }; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| input = CreateFloatTensorFromArray(inputSignal, inputShape); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can you construct a onnxruntime-genai/src/csharp/Tensor.cs Lines 28 to 45 in 9d3631d
The Here is an example with onnxruntime-genai/test/csharp/TestOnnxRuntimeGenAIAPI.cs Lines 637 to 669 in 9d3631d
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, this approach gave me memory corruption as well, there might be an issue with Tensor class, but I can give it another try. |
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| sr = CreateInt64TensorFromArray(new long[] { sampleRate }, new long[] { 1 }); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| frame = CreateInt64TensorFromArray(new long[] { frameMs }, new long[] { 1 }); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| hop = CreateInt64TensorFromArray(new long[] { hopMs }, new long[] { 1 }); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| thr = CreateFloatTensorFromArray(new float[] { energyThresholdDb }, new long[] { 1 }); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| splitOut = CreateOutputInt64Tensor(splitBacking, MaxSegs, 2); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| mergedOut = CreateOutputInt64Tensor(mergedBacking, MaxSegs, 2); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| SplitSignalSegments(input, sr, frame, hop, thr, splitOut); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| long[] splitShapeBuf = new long[2]; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Result.VerifySuccess(NativeMethods.OgaTensorGetShape(splitOut, splitShapeBuf, (UIntPtr)splitShapeBuf.Length)); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| long splitRows = splitShapeBuf[0]; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| long splitCols = splitShapeBuf[1]; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| mergeGap = CreateInt64TensorFromArray(new long[] { mergeGapMs }, new long[] { 1 }); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| MergeSignalSegments(splitOut, mergeGap, mergedOut); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| long[] mergedShapeBuf = new long[2]; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Result.VerifySuccess(NativeMethods.OgaTensorGetShape(mergedOut, mergedShapeBuf, (UIntPtr)mergedShapeBuf.Length)); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| long mergedRowsDbg = mergedShapeBuf[0]; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| long mergedColsDbg = mergedShapeBuf[1]; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| long[] shapeBuf = new long[2]; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Result.VerifySuccess(NativeMethods.OgaTensorGetShape(mergedOut, shapeBuf, (UIntPtr)shapeBuf.Length)); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| long mergedRows = shapeBuf[0]; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| long mergedCols = shapeBuf[1]; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if (mergedCols != 2) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| throw new InvalidOperationException($"Expected merged output with 2 columns, got {mergedCols}"); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| // Convert to array of start/end tuples. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| var result = new List<(double Start, double End)>(); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for (int i = 0; i < mergedRows; ++i) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| long start = mergedBacking[i * 2 + 0]; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| long end = mergedBacking[i * 2 + 1]; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if (start == 0 && end == 0) continue; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| result.Add((start, end)); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return result.ToArray(); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| finally | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if (input != IntPtr.Zero) NativeMethods.OgaDestroyTensor(input); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if (sr != IntPtr.Zero) NativeMethods.OgaDestroyTensor(sr); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if (frame != IntPtr.Zero) NativeMethods.OgaDestroyTensor(frame); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if (hop != IntPtr.Zero) NativeMethods.OgaDestroyTensor(hop); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if (thr != IntPtr.Zero) NativeMethods.OgaDestroyTensor(thr); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if (splitOut != IntPtr.Zero) NativeMethods.OgaDestroyTensor(splitOut); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if (mergeGap != IntPtr.Zero) NativeMethods.OgaDestroyTensor(mergeGap); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if (mergedOut != IntPtr.Zero) NativeMethods.OgaDestroyTensor(mergedOut); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -78,6 +78,36 @@ T* ReturnUnique(std::unique_ptr<U> p) { | |
| return static_cast<T*>(p.release()); | ||
| } | ||
|
|
||
| // Helper function to convert OgaTensor to OrtxTensor, sometimes needed as an input to onnxruntime-extensions methods. | ||
| template <typename T> | ||
| static OrtxTensor* MakeOrtxTensor(OgaTensor* src) { | ||
| if (!src) { | ||
| throw std::runtime_error("Null tensor passed to MakeOrtxTensor"); | ||
| } | ||
|
|
||
| auto* gen = reinterpret_cast<Generators::Tensor*>(src); | ||
| T* data = const_cast<T*>(gen->GetData<T>()); | ||
| std::vector<int64_t> shape = gen->GetShape(); | ||
|
|
||
| auto* ort = new Ort::Custom::Tensor<T>(shape, data); | ||
| return reinterpret_cast<OrtxTensor*>(ort); | ||
| } | ||
|
|
||
nenad1002 marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| // Helper function to convert const OgaTensor to const OrtxTensor, sometimes needed as an input to onnxruntime-extensions methods. | ||
| template <typename T> | ||
| static const OrtxTensor* MakeOrtxTensorConst(const OgaTensor* src) { | ||
| if (!src) { | ||
| throw std::runtime_error("Null tensor passed to MakeOrtxTensorConst"); | ||
| } | ||
|
|
||
| auto* gen = reinterpret_cast<const Generators::Tensor*>(src); | ||
| const T* data = gen->GetData<T>(); | ||
| std::vector<int64_t> shape = gen->GetShape(); | ||
|
|
||
| auto* ort = new Ort::Custom::Tensor<T>(shape, const_cast<T*>(data)); | ||
| return reinterpret_cast<const OrtxTensor*>(ort); | ||
| } | ||
|
|
||
| extern "C" { | ||
|
|
||
| #define OGA_TRY try { | ||
|
|
@@ -888,6 +918,68 @@ OgaResult* OGA_API_CALL OgaProcessorProcessImagesAndAudiosAndPrompts(const OgaMu | |
| OGA_CATCH | ||
| } | ||
|
|
||
| OGA_EXPORT OgaResult* OGA_API_CALL OgaSplitSignalSegments( | ||
| const OgaTensor* input, | ||
| const OgaTensor* sr_tensor, | ||
| const OgaTensor* frame_ms_tensor, | ||
| const OgaTensor* hop_ms_tensor, | ||
| const OgaTensor* energy_threshold_db_tensor, | ||
| OgaTensor* output0) { | ||
| OGA_TRY | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can we add unit tests for the new APIs in ORT GenAI's C API tests? |
||
| if (!input || !sr_tensor || !frame_ms_tensor || !hop_ms_tensor || | ||
| !energy_threshold_db_tensor || !output0) { | ||
| throw std::runtime_error("Null tensor argument passed to OgaSplitSignalSegments"); | ||
| } | ||
|
|
||
| const OrtxTensor* in_tensor = MakeOrtxTensorConst<float>(input); | ||
| const OrtxTensor* sr_tensor_obj = MakeOrtxTensorConst<int64_t>(sr_tensor); | ||
| const OrtxTensor* frame_tensor = MakeOrtxTensorConst<int64_t>(frame_ms_tensor); | ||
nenad1002 marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| const OrtxTensor* hop_tensor = MakeOrtxTensorConst<int64_t>(hop_ms_tensor); | ||
| const OrtxTensor* thr_tensor = MakeOrtxTensorConst<float>(energy_threshold_db_tensor); | ||
| OrtxTensor* out_tensor = MakeOrtxTensor<int64_t>(output0); | ||
|
|
||
| extError_t err = OrtxSplitSignalSegments( | ||
| in_tensor, | ||
| sr_tensor_obj, | ||
| frame_tensor, | ||
| hop_tensor, | ||
| thr_tensor, | ||
| out_tensor); | ||
|
|
||
| if (err != kOrtxOK) { | ||
| throw std::runtime_error(OrtxGetLastErrorMessage()); | ||
| } | ||
| return nullptr; | ||
|
|
||
| OGA_CATCH | ||
| } | ||
|
|
||
| OGA_EXPORT OgaResult* OGA_API_CALL OgaMergeSignalSegments( | ||
| const OgaTensor* segments_tensor, | ||
| const OgaTensor* merge_gap_ms_tensor, | ||
| OgaTensor* output0) { | ||
| OGA_TRY | ||
| if (!segments_tensor || !merge_gap_ms_tensor || !output0) { | ||
| throw std::runtime_error("Null tensor argument passed to OgaMergeSignalSegments"); | ||
| } | ||
|
|
||
| const OrtxTensor* seg_tensor = MakeOrtxTensorConst<int64_t>(segments_tensor); | ||
| const OrtxTensor* gap_tensor = MakeOrtxTensorConst<int64_t>(merge_gap_ms_tensor); | ||
| OrtxTensor* out_tensor = MakeOrtxTensor<int64_t>(output0); | ||
|
|
||
| extError_t err = OrtxMergeSignalSegments( | ||
| seg_tensor, | ||
| gap_tensor, | ||
| out_tensor); | ||
|
|
||
| if (err != kOrtxOK) { | ||
| throw std::runtime_error(OrtxGetLastErrorMessage()); | ||
| } | ||
| return nullptr; | ||
|
|
||
| OGA_CATCH | ||
| } | ||
|
|
||
| OgaResult* OGA_API_CALL OgaCreateStringArray(OgaStringArray** out) { | ||
| OGA_TRY | ||
| *out = ReturnUnique<OgaStringArray>(std::make_unique<std::vector<std::string>>()); | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -705,6 +705,19 @@ OGA_EXPORT OgaResult* OGA_API_CALL OgaProcessorProcessImagesAndAudios(const OgaM | |
| */ | ||
| OGA_EXPORT OgaResult* OGA_API_CALL OgaProcessorProcessImagesAndAudiosAndPrompts(const OgaMultiModalProcessor*, const OgaStringArray* prompts, const OgaImages* images, const OgaAudios* audios, OgaNamedTensors** input_tensors); | ||
|
|
||
| OGA_EXPORT OgaResult* OGA_API_CALL OgaSplitSignalSegments( | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can you add C# API tests that call these C APIs? |
||
| const OgaTensor* input, | ||
| const OgaTensor* sr_tensor, | ||
| const OgaTensor* frame_ms_tensor, | ||
| const OgaTensor* hop_ms_tensor, | ||
| const OgaTensor* energy_threshold_db_tensor, | ||
| OgaTensor* output0); | ||
|
|
||
| OGA_EXPORT OgaResult* OGA_API_CALL OgaMergeSignalSegments( | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Let's add comments above these new APIs to explain their usage and their parameters. |
||
| const OgaTensor* segments_tensor, | ||
| const OgaTensor* merge_gap_ms_tensor, | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Given that the type of the object is an |
||
| OgaTensor* output0); | ||
|
|
||
| /** Decode a single token sequence and returns a null terminated utf8 string. out_string must be freed with OgaDestroyString | ||
| */ | ||
| OGA_EXPORT OgaResult* OGA_API_CALL OgaTokenizerDecode(const OgaTokenizer*, const int32_t* tokens, size_t token_count, const char** out_string); | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we follow the same tab spacings as the other native method APIs listed in this file for the new APIs?