-
Notifications
You must be signed in to change notification settings - Fork 3.7k
[java] Adds MemorySegment support to allow OrtValues greater than 2GB #26911
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
| * @throws OrtException Thrown if there is an onnx error or if the data and shape don't match. | ||
| */ | ||
| static OnnxTensor createTensor( | ||
| OrtEnvironment env, |
Check notice
Code scanning / CodeQL
Useless parameter Note
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we comment out the name, or remove the parameter?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It enforces in the type system that the OrtEnvironment has been created before creating any tensors, and I've seen weird error messages if the user creates them the other way around. Java 8 doesn't have a way to specify this variable is unused (e.g. _ as a variable name), that only came in in Java 21.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I am afraid it does not enforce anything. I can simply pass null as a value. People will discover this code and will use the opportunity and then it would be hard to change.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ok, I'll put a null check in then.
| } | ||
| } | ||
|
|
||
| public void close() { |
Check notice
Code scanning / CodeQL
Missing Override annotation Note test
AutoCloseable.close
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed.
| OnnxTensor tensor = | ||
| OnnxTensor.createTensorFromMemorySegment(env, notASegment, shape, OnnxJavaType.FLOAT); |
Check notice
Code scanning / CodeQL
Unread local variable Note test
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This method throws an exception when called with these arguments and the test is checking that it throws the right exception with the right message.
| OnnxTensor.createTensorFromMemorySegment(env, segment.get(), shape, OnnxJavaType.FLOAT); | ||
|
|
||
| try { | ||
| FloatBuffer fb = bigTensor.getFloatBuffer(); |
Check notice
Code scanning / CodeQL
Unread local variable Note test
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This method throws an exception when called in this state and the test is checking that it throws the right exception with the right message.
| } | ||
|
|
||
| try { | ||
| float[][] arr = (float[][]) bigTensor.getValue(); |
Check notice
Code scanning / CodeQL
Unread local variable Note test
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This method throws an exception when called in this state and the test is checking that it throws the right exception with the right message.
|
|
||
| /** | ||
| * Wrapper for java.lang.foreign.MemorySegment instances which uses reflection to access the methods | ||
| * so it can be compiled on Java 21 and earlier. Requires Java 22 or newer to use MemorySegments, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The following is not clear from documentation for two versions of this class:
- If the version of Java is 22, does it mean that
java.lang.foreign.MemorySegmentis guaranteed to be supported? If so, then why we need reflection? - If the answer to the above is no, then would it not make sense to create only one version of the class that would check both Java version and the availability of the feature?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If it's on Java 22 or newer then it is guaranteed to be supported. However Java doesn't have conditional compilation, and the multi-release jar feature which allows different code to be loaded depending on the Java version has poor tooling support and requires that every version have the same public API which make it not suitable for this use case.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Conditional compilation is not the only way to address the issue. In fact, since we publish binaries, it would not be applicable anyway.
One way to address this is to have difference private implementation classes behind MemorySegment facade and instantiate one depending on which version and environment we run in.
However, my question remains. Suppose I am a developer who is writing code that is supposed to work under the following circumstances:
- Android
- Java version < 22
- Java version >=22
How do I write it in way that is convenient and operational and transparent? This is the key design question.
Here is one more specific: How do I avoid changing my client code and it would automatically take advantage of big tensors, so I do not have to do anything.
If, we are not in the environment that does not support big tensors, everything functions as before UNLESS the tensor is too big.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm not sure what "take advantage of big tensors" means? If the user wants to interact with a big tensor in Java they need to get it into a Java side object rather than an ORT native object, and there are no Java side objects big enough until Java 22. If they want to use a big tensor output by one model and pass it into another model without looking at it, they can already do that and nothing changes.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I am talking about the right abstraction on top whatever Java provides and maybe does not, to expose the right types, right sizes and right buffers. I think it is possible to do given the AI capabilities.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
And I still do not understand why we need a separate empty shell class for Android. Would not simply fail to load as in this class?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Some Android versions don't have support for MethodHandle and throw if you try to use them. I'm not sure if ORT targets a new enough version of Android to avoid this problem, but if it does then I can remove the Android side bits for this and also for the FP16 conversions which use similar tricks.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I am talking about the right abstraction on top whatever Java provides and maybe does not, to expose the right types, right sizes and right buffers. I think it is possible to do given the AI capabilities.
I am wary of writing a Java tensor type that only lives in this library and is disconnected from the ORT native code. Other Java ML libraries also have those and they all have no interop between those libraries aside from JDK classes like ByteBuffer and MemorySegment because the JDK does not provide useful interfaces for tensor types to implement. The choices I've been making with ORT are to allow users to bring their data in whatever form they have it and it's their job to get it into a JDK class which we consume directly without copies (which is typically where their data already is because those JDK classes underlie all fast off heap memory in Java and so they are underneath whatever user tensor abstraction they already have).
| | NoSuchMethodException | ||
| | ClassNotFoundException | ||
| | NoSuchFieldException e) { | ||
| logger.fine("Running on Java 21 or earlier, MemorySegment not available"); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I've refactored the lookup logic here to make it smaller.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
My question was, can you call some API, find out the Java version, simply compare an integer and not proceed any further?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In Java 8 I can parse the version string out of the java.version system property, but that's more code than the Class.forName and I'll still need the try catch around the Class.forName and the method handle lookups as those throw checked exceptions. So it'll be longer.
| tmpLayout = null; | ||
| } catch (Throwable e) { | ||
| logger.severe( | ||
| "Failed to load float value layout, while other Java 22 features were available."); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This exception is only thrown in the case where we failed to execute the method to create the float value layout, and I can't think of a case where that could actually happen. However the invokeExact call declares that it throws Throwable so I need to catch it here. I've refactored this check a bit now.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The message is confusing. There are no other Java 22 features that are supported in the code, otherwise we would have more checks for Java version.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ok, this has been fixed in the refactor of the method handle construction.
| this.segment = segment; | ||
| } else { | ||
| throw new IllegalArgumentException( | ||
| "Segment argument was not a java.lang.foreign.MemorySegment, found " |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, that's probably unnecessary safety as it shouldn't be possible to manufacture one of these without doing more nasty reflection tricks or using JNI to make an instance without going through the constructor.
| * native code encountered an error. | ||
| */ | ||
| public ByteBuffer getByteBuffer() { | ||
| public ByteBuffer getByteBuffer() throws OrtException { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Unfortunately this is a bug on my part. That method has always been able to throw OrtException as the exception can be created in JNI by this line, but I forgot to add it to the Java signature here so it was never checked.
Now I think the underlying GetTensorMutableData call can't currently return anything other than ORT_OK, but I don't watch PRs that modify the implementation of the C API, so it could do in the future. If we can guarantee that it'll never return anything other than ORT_OK then I can take it back out and add a comment to that effect in the Java code.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think you can rely on the ABI guarantees. However, if the ABI signature returns an OrtStatus, then one should be prepared to handle an error.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It returns an OrtStatus, and everywhere else code that touches one of those is marked throws OrtException.
| * native code encountered an error. | ||
| */ | ||
| public FloatBuffer getFloatBuffer() { | ||
| public FloatBuffer getFloatBuffer() throws OrtException { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
| try { | ||
| return getBuffer(OnnxRuntime.ortApiHandle, nativeHandle).order(ByteOrder.nativeOrder()); | ||
| } catch (IllegalArgumentException e) { | ||
| // thrown by the byte buffer constructor if the tensor is bigger than Integer.MAX_VALUE. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The IllegalArgumentException that is thrown is thrown by the JDK itself from the JNI method that wraps a direct byte buffer around a pointer. I could modify the JNI logic to check how large the pointer is before calling the direct byte buffer constructor and throw an OrtException, but I'd have to construct the string error message down in C and I try to avoid doing string handling in C if I can.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Perhaps, you can return a meaningful error code and deal with strings in Java code. This would also help to create a transparent API on the Java side that would automatically handle big buffers without need to introduce new public API classes.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm not sure what you mean? It would need to be a different return type on the Java side to send back more than 2GB, and we can't bump the version to Java 25 which is the LTS version which supports this. Android doesn't have MemorySegment and I don't think it's going to get it any time soon.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You can create/call a native method to find out from ORT the size of the date and then bail early.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed.
| */ | ||
| public Object getMemorySegment() throws OrtException { | ||
| long[] info = getSegmentPointer(OnnxRuntime.ortApiHandle, nativeHandle); | ||
| MemorySegmentShim shim = new MemorySegmentShim(info[0], info[1]); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
MemorySegmentShim shim = new MemorySegmentShim(info[0], info[1])
This will throw if MemorySegment is not available.
Is there a better way to discover the fact so the client code can choose a different path of execution? Handling exceptions to alter the flow of execution is less convenient than if/else.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
My expectation is that all user code that wants to call this ORT code will have the concrete MemorySegment type in it somewhere so it'll have to compile and run with Java 22. Then all this will work and it won't throw exceptions. It's only if the user also can't control what JVM they are running on but wants to selectively run things through MemorySegment if they are running on a newer JVM. But if they did that they still wouldn't be able to make 2GB+ tensors on Java 21 or earlier and the user code logic would have to be even more contorted and full of reflection than this is.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think what users want is to have their code continue as before, with an added benefit of large tensors behind the scenes if the environment supports it.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Unfortunately they can't have that, because they can't get the value out of a big OnnxTensor into Java or construct a big tensor from a Java side value without being on a new Java version, and the big values have to use a different type because MemorySegment and ByteBuffer don't share any super type other than Object.
If we wanted to rev the Java version to 25 we could only use MemorySegment the way that OpenJDK's ONNX code transformation experiment does, but that would fork the Java & Android APIs, and also abandon everyone who hadn't moved to 25 yet (which only came out September last year).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we have a problem with the API. It does not expose the right data types. For example, the native size of the tensor in ORT is expressed as 64-bit integer, but the public API exposed primitives are dealing with int as sizes. This is wrong.
I suggest we re-think the API at this point and create a new one, that would expose the right data types and then transparently deals with both small and big tensors w/o user being concerned about the environment it runs in and the size of the tensor.
Is this not the premise of Java anyway?
| if (!validateShape() && numElements != 0) { | ||
| if ((!validateShape() && numElements != 0) || (numElements * type.size >= Integer.MAX_VALUE)) { | ||
| throw new OrtException( | ||
| "This tensor is not representable in Java, it's too big - shape = " |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, good point, I've updated the message.
| * @throws IllegalArgumentException If the MemorySegment is not on the native heap. | ||
| * @throws OrtException Thrown if there is an onnx error or if the data and shape don't match. | ||
| */ | ||
| public static OnnxTensor createTensorFromMemorySegment( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Unfortunately not, there's no conditional compilation or method visibility. Using a multi-release jar which allows different implementations for different Java versions also requires that the public API of all the classes is the same, so even if I did this the more complicated way it would still be visible but unusable in earlier Java versions.
yuslepukhin
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🕐
|
Please, address and/or comment on the Copilot comments. |
Will do. |
| import org.junit.jupiter.api.condition.JRE; | ||
|
|
||
| /** Tests for interop with Java 22's MemorySegments. */ | ||
| @EnabledForJreRange(min = JRE.JAVA_22) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks like a conditional compilation or something like that.
From the example of this test, the user would have to create different binaries for Android and different versions of Java. Can we avoid that?
…o OnnxTensor.createFromMemorySegment.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Copilot encountered an error and was unable to review this pull request. You can try again by re-requesting a review.
Description
Adds a method to create
OnnxTensorinstances backed by Java 22'sMemorySegment, and also a method to get a value out from aOnnxTensorin aMemorySegment. This is done via reflection over the JDK classes to determine if theMemorySegmentclass is available, and if not it throwsUnsupportedOperationException. Consequently this code compiles fine on Java 8, but theMemorySegmentbased operations are only available when ORT is running on Java 22 or newer. It has a different shim for Android which unconditionally throwsUnsupportedOperationExceptionas Android doesn't support the Java 22 FFM API.Due to the reflection the methods which interact with
MemorySegmentacceptObjectand do a runtime type check that the supplied argument is actually aMemorySegment. It means that the compile time type check that would ordinarily occur in user code is shifted to runtime, but it's still type safe and checked.It would be good to update the CI to run the Java tests on Java 22 or newer (e.g. on Java 25, the LTS version which includes Java 22 features) as well as at least one test on Java 21 or older, but I'm not sure what the Azure CI environment is like so I might need some assistance to do that.
There are two related bug fixes in here as well, the first is that
getBufferRefused to throwNullPointerExceptionif it was called on a tensor that wasn't backed by a buffer, that was an oversight when I modified the duplicate call, and the second is that theget<type>Buffercalls can throwOrtExceptionup from the JNI code and so should have been taggedthrows OrtException. This latter issue is pretty unlikely as ORT doesn't error from that method very often otherwise we'd have seen issues.Motivation and Context
Java's byte buffers and arrays are limited to a maximum of 2^31 - 1 elements as they are indexed by a Java int. To load some models with external initializers we need to be able to create tensors with more elements (e.g. to hold the embedding matrix for some LLMs), which is impossible with the current code.