Skip to content

Commit 20c7a4e

Browse files
authored
Update TFLite implementation with ByteBufferGlEffect (#209)
* Update TFLite implementation with ByteBufferGlEffect * Update media3 version * Update media3 dependency
1 parent 06d584d commit 20c7a4e

File tree

7 files changed

+224
-245
lines changed

7 files changed

+224
-245
lines changed

app/src/main/AndroidManifest.xml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,12 @@
3434
<category android:name="android.intent.category.LAUNCHER" />
3535
</intent-filter>
3636
</activity>
37+
38+
<!-- Required for the TFLite/LiteRT style transfer demo.
     The OpenCL libraries are native (.so) libraries, so they are declared with
     uses-native-library rather than uses-library (which is for Java shared libraries).
     android:required="false" keeps the app installable on devices that do not ship them. -->
<uses-native-library android:name="libOpenCL.so"
    android:required="false"/>
<uses-native-library android:name="libOpenCL-pixel.so"
    android:required="false"/>
3743
</application>
3844

3945
</manifest>

gradle/libs.versions.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ androidxTestExtTruth = "1.5.0"
3939
androidxTestRules = "1.5.0"
4040
androidxTestRunner = "1.5.2"
4141
androidxUiAutomator = "2.2.0"
42-
media3 = "1.4.0-rc01"
42+
media3 = "1.5.0"
4343
appcompat = "1.6.1"
4444
material = "1.12.0-beta01"
4545
constraintlayout = "2.1.4"

samples/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ A sample showcasing how to handle calls with the Jetpack Telecom API
117117
- [TextSpan](user-interface/text/src/main/java/com/example/platform/ui/text/TextSpan.kt):
118118
buildSpannedString is useful for quickly building a rich text.
119119
- [Transformer and TFLite](media/video/src/main/java/com/example/platform/media/video/TransformerTFLite.kt):
120-
This sample demonstrates using Transformer with TFLite by applying a selected art style to a video.
120+
This sample demonstrates using Transformer with TFLite/LiteRT by applying a selected art style to a video.
121121
- [UltraHDR Image Capture](camera/camera2/src/main/java/com/example/platform/camera/imagecapture/Camera2UltraHDRCapture.kt):
122122
This sample demonstrates how to capture a 10-bit compressed still image and
123123
- [UltraHDR to HDR Video](media/ultrahdr/src/main/java/com/example/platform/media/ultrahdr/video/UltraHDRToHDRVideo.kt):

samples/media/video/download_model.gradle

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,16 +13,39 @@
1313
* See the License for the specific language governing permissions and
1414
* limitations under the License.
1515
*/
16-
tasks.register('downloadModelFile1', Download) {
16+
// Fetches the int8 style-prediction model into the assets directory.
// `overwrite false` skips the download when the file is already present.
task downloadModelFile(type: Download) {
    src 'https://storage.googleapis.com/download.tensorflow.org/models/tflite/task_library/style_transfer/android/magenta_arbitrary-image-stylization-v1-256_int8_prediction_1.tflite'
    dest project.ext.ASSET_DIR + '/predict_int8.tflite'
    overwrite false
}

// Fetches the int8 style-transfer model that pairs with predict_int8.tflite.
task downloadModelFile0(type: Download) {
    src 'https://storage.googleapis.com/download.tensorflow.org/models/tflite/task_library/style_transfer/android/magenta_arbitrary-image-stylization-v1-256_int8_transfer_1.tflite'
    dest project.ext.ASSET_DIR + '/transfer_int8.tflite'
    overwrite false
}

// Fetches the float16 style-prediction model (loaded by StyleTransferEffect).
task downloadModelFile1(type: Download) {
    src 'https://storage.googleapis.com/download.tensorflow.org/models/tflite/task_library/style_transfer/android/magenta_arbitrary-image-stylization-v1-256_fp16_prediction_1.tflite'
    dest project.ext.ASSET_DIR + '/predict_float16.tflite'
    overwrite false
}

// Fetches the float16 style-transfer model (loaded by StyleTransferEffect).
task downloadModelFile2(type: Download) {
    src 'https://storage.googleapis.com/download.tensorflow.org/models/tflite/task_library/style_transfer/android/magenta_arbitrary-image-stylization-v1-256_fp16_transfer_1.tflite'
    dest project.ext.ASSET_DIR + '/transfer_float16.tflite'
    overwrite false
}

// Copies the float16 prediction model into the test assets directory once it is downloaded.
task copyTestModel(type: Copy, dependsOn: downloadModelFile1) {
    from project.ext.ASSET_DIR + '/predict_float16.tflite'
    into project.ext.TEST_ASSETS_DIR
}

// Copies the float16 transfer model into the test assets directory once it is downloaded.
task copyTestModel0(type: Copy, dependsOn: downloadModelFile2) {
    from project.ext.ASSET_DIR + '/transfer_float16.tflite'
    into project.ext.TEST_ASSETS_DIR
}

// Every download/copy task must be listed here, otherwise it never runs as part of a build.
// downloadModelFile0 was previously declared but not wired in, so transfer_int8.tflite was
// never fetched even though its companion predict_int8.tflite was.
preBuild.dependsOn downloadModelFile, downloadModelFile0, downloadModelFile1, downloadModelFile2,
        copyTestModel, copyTestModel0
Lines changed: 185 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,185 @@
1+
/*
2+
* Copyright 2024 The Android Open Source Project
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package com.example.platform.media.video
18+
19+
import android.content.Context
20+
import android.graphics.Bitmap
21+
import android.graphics.BitmapFactory
22+
import android.graphics.Matrix
23+
import androidx.media3.common.GlTextureInfo
24+
import androidx.media3.common.VideoFrameProcessingException
25+
import androidx.media3.common.util.GlRect
26+
import androidx.media3.common.util.GlUtil
27+
import androidx.media3.common.util.Size
28+
import androidx.media3.common.util.UnstableApi
29+
import androidx.media3.common.util.Util
30+
import androidx.media3.effect.ByteBufferGlEffect
31+
import com.google.common.collect.ImmutableMap
32+
import com.google.common.util.concurrent.ListenableFuture
33+
import com.google.common.util.concurrent.ListeningExecutorService
34+
import com.google.common.util.concurrent.MoreExecutors
35+
import org.tensorflow.lite.DataType
36+
import org.tensorflow.lite.Interpreter
37+
import org.tensorflow.lite.InterpreterApi
38+
import org.tensorflow.lite.gpu.CompatibilityList
39+
import org.tensorflow.lite.gpu.GpuDelegate
40+
import org.tensorflow.lite.support.common.FileUtil
41+
import org.tensorflow.lite.support.common.ops.DequantizeOp
42+
import org.tensorflow.lite.support.common.ops.NormalizeOp
43+
import org.tensorflow.lite.support.image.ImageProcessor
44+
import org.tensorflow.lite.support.image.TensorImage
45+
import org.tensorflow.lite.support.image.ops.ResizeOp
46+
import org.tensorflow.lite.support.image.ops.ResizeWithCropOrPadOp
47+
import org.tensorflow.lite.support.tensorbuffer.TensorBuffer
48+
import java.util.concurrent.Future
49+
50+
/**
 * [ByteBufferGlEffect.Processor] that restyles video frames with a pair of TFLite models:
 * a "predict" model that encodes the style image, and a "transfer" model that combines that
 * encoding with each frame.
 *
 * Construction is not cheap: it loads both model files, decodes the style image from the app's
 * assets and runs the predict model once on the calling thread. The resulting style encoding is
 * kept and reused for every frame.
 *
 * Per-frame work is split across three single-threaded executors (pre-processing, inference and
 * post-processing), chained through futures in [processImage].
 *
 * [release] shuts the executors down and closes the interpreter and GPU delegate, so an instance
 * must not be used again after it has been released.
 *
 * @param context used to load the model files and to open the style image asset.
 * @param styleAssetFileName path of the style image inside the app's assets.
 * @throws IllegalArgumentException if the style image asset cannot be decoded into a bitmap.
 */
@UnstableApi
class StyleTransferEffect(context: Context, styleAssetFileName: String) : ByteBufferGlEffect.Processor<Bitmap> {

    private val transformInterpreter: InterpreterApi
    private val inputTransformTargetHeight: Int
    private val inputTransformTargetWidth: Int
    private val outputTransformShape: IntArray

    /** GPU delegate attached to the interpreters, or `null` when inference runs without one. */
    private val gpuDelegate: GpuDelegate?

    private val preProcessExecutor: ListeningExecutorService = MoreExecutors.listeningDecorator(
        Util.newSingleThreadExecutor("preProcess"),
    )
    private val postProcessExecutor: ListeningExecutorService = MoreExecutors.listeningDecorator(
        Util.newSingleThreadExecutor("postProcess"),
    )
    private val tfRunExecutor: ListeningExecutorService = MoreExecutors.listeningDecorator(
        Util.newSingleThreadExecutor("tfRun"),
    )

    /** Output of the predict model for the style image; fed to the transfer model on every frame. */
    private val predictOutput: TensorBuffer

    private var inputWidth: Int = 0
    private var inputHeight: Int = 0

    init {
        val options = Interpreter.Options()
        // Only attach the GPU delegate when the device is reported as supported; otherwise the
        // interpreters are created without it instead of failing on unsupported hardware.
        gpuDelegate = CompatibilityList().use { compatibilityList ->
            if (compatibilityList.isDelegateSupportedOnThisDevice) {
                GpuDelegate(compatibilityList.bestOptionsForThisDevice)
            } else {
                null
            }
        }
        gpuDelegate?.let { options.addDelegate(it) }

        val predictInterpreter = Interpreter(FileUtil.loadMappedFile(context, PREDICT_MODEL), options)
        transformInterpreter = InterpreterApi.create(FileUtil.loadMappedFile(context, TRANSFER_MODEL), options)

        // Indices 1 and 2 of the input shape are read as height and width
        // (assumes an NHWC input tensor; confirm against the model).
        val transformInputShape = transformInterpreter.getInputTensor(0).shape()
        inputTransformTargetHeight = transformInputShape[1]
        inputTransformTargetWidth = transformInputShape[2]
        outputTransformShape = transformInterpreter.getOutputTensor(0).shape()

        // The predict model is only needed once, so it is closed as soon as the style is encoded.
        predictOutput = try {
            encodeStyle(context, styleAssetFileName, predictInterpreter)
        } finally {
            predictInterpreter.close()
        }
    }

    /**
     * Records the input frame size and returns the size of the transfer model's input, which is
     * the size this processor wants its images delivered in.
     */
    override fun configure(inputWidth: Int, inputHeight: Int): Size {
        this.inputWidth = inputWidth
        this.inputHeight = inputHeight
        return Size(inputTransformTargetWidth, inputTransformTargetHeight)
    }

    /**
     * Returns a square region anchored at (0, 0) whose side is the shorter side of the input
     * frame. The same region is used as the blit target in [finishProcessingAndBlend].
     */
    override fun getScaledRegion(presentationTimeUs: Long): GlRect {
        val minSide = minOf(inputWidth, inputHeight)
        return GlRect(0, 0, minSide, minSide)
    }

    /**
     * Runs the three-stage pipeline (pre-process, inference, post-process) for one frame.
     *
     * @return a future that completes with the stylized frame as a [Bitmap].
     */
    override fun processImage(
        image: ByteBufferGlEffect.Image,
        presentationTimeUs: Long,
    ): ListenableFuture<Bitmap> {
        val tensorImageFuture = preProcess(image)
        val tensorBufferFuture = tfRun(tensorImageFuture)
        return postProcess(tensorBufferFuture)
    }

    /**
     * Stops the executors and frees the interpreter and GPU delegate.
     *
     * The native resources are closed by a task queued on the inference executor, so they are
     * released only after any inference that was already submitted. This method does not block
     * and is safe to call more than once.
     */
    override fun release() {
        if (tfRunExecutor.isShutdown) return
        preProcessExecutor.shutdown()
        tfRunExecutor.execute {
            transformInterpreter.close()
            gpuDelegate?.close()
        }
        // shutdown() still runs tasks that were already queued, including the close task above.
        tfRunExecutor.shutdown()
        postProcessExecutor.shutdown()
    }

    /**
     * Draws the stylized [result] into [outputFrame] over the region from [getScaledRegion].
     *
     * @throws VideoFrameProcessingException if a GL call fails.
     */
    override fun finishProcessingAndBlend(
        outputFrame: GlTextureInfo,
        presentationTimeUs: Long,
        result: Bitmap,
    ) {
        try {
            copyBitmapToFbo(result, outputFrame, getScaledRegion(presentationTimeUs))
        } catch (e: GlUtil.GlException) {
            throw VideoFrameProcessingException.from(e)
        }
    }

    /**
     * Encodes the style image with the predict model.
     *
     * @return a FLOAT32 buffer, shaped like the model's first output, holding the style encoding.
     */
    private fun encodeStyle(
        context: Context,
        styleAssetFileName: String,
        predictInterpreter: Interpreter,
    ): TensorBuffer {
        val inputShape = predictInterpreter.getInputTensor(0).shape()
        val targetHeight = inputShape[1]
        val targetWidth = inputShape[2]

        // `use` closes the stream even if decoding throws.
        val styleImage: Bitmap = requireNotNull(
            context.assets.open(styleAssetFileName).use { BitmapFactory.decodeStream(it) },
        ) { "Could not decode style image asset: $styleAssetFileName" }

        val styleTensorImage = getScaledTensorImage(styleImage, targetWidth, targetHeight)
        val styleEncoding =
            TensorBuffer.createFixedSize(predictInterpreter.getOutputTensor(0).shape(), DataType.FLOAT32)
        predictInterpreter.run(styleTensorImage.buffer, styleEncoding.buffer)
        return styleEncoding
    }

    /** Copies the frame to a bitmap and converts it to the transfer model's input tensor. */
    private fun preProcess(image: ByteBufferGlEffect.Image): ListenableFuture<TensorImage> {
        return preProcessExecutor.submit<TensorImage> {
            val bitmap = image.copyToBitmap()
            getScaledTensorImage(bitmap, inputTransformTargetWidth, inputTransformTargetHeight)
        }
    }

    /**
     * Runs the transfer model on the inference executor once the pre-processed frame is ready.
     * The model receives the frame tensor and the stored style encoding, in that order.
     */
    private fun tfRun(tensorImageFuture: Future<TensorImage>): ListenableFuture<TensorBuffer> {
        return tfRunExecutor.submit<TensorBuffer> {
            // Blocks this executor's thread until pre-processing has finished.
            val tensorImage = tensorImageFuture.get()
            val outputImage = TensorBuffer.createFixedSize(outputTransformShape, DataType.FLOAT32)

            transformInterpreter.runForMultipleInputsOutputs(
                arrayOf(tensorImage.buffer, predictOutput.buffer),
                ImmutableMap.builder<Int, Any>().put(0, outputImage.buffer).build(),
            )
            outputImage
        }
    }

    /** Converts the model output into a [Bitmap] once inference has finished. */
    private fun postProcess(futureOutputImage: ListenableFuture<TensorBuffer>): ListenableFuture<Bitmap> {
        return postProcessExecutor.submit<Bitmap> {
            val outputImage = futureOutputImage.get()
            // DequantizeOp(0f, 255f) is the counterpart of the NormalizeOp(0f, 255f) applied on
            // input (presumably scaling values back to the 0..255 pixel range).
            val imagePostProcessor = ImageProcessor.Builder()
                .add(DequantizeOp(0f, 255f))
                .build()
            val outputTensorImage = TensorImage(DataType.FLOAT32)
            outputTensorImage.load(outputImage)
            imagePostProcessor.process(outputTensorImage).bitmap
        }
    }

    /**
     * Crops or pads [bitmap] to a square of its shorter side, resizes it to
     * [targetWidth] x [targetHeight] with bilinear filtering, applies `NormalizeOp(0f, 255f)`
     * and returns the result as a FLOAT32 [TensorImage].
     */
    private fun getScaledTensorImage(bitmap: Bitmap, targetWidth: Int, targetHeight: Int): TensorImage {
        val cropSize = minOf(bitmap.width, bitmap.height)
        val imageProcessor = ImageProcessor.Builder()
            .add(ResizeWithCropOrPadOp(cropSize, cropSize))
            // ResizeOp takes height before width.
            .add(ResizeOp(targetHeight, targetWidth, ResizeOp.ResizeMethod.BILINEAR))
            .add(NormalizeOp(0f, 255f))
            .build()
        val tensorImage = TensorImage(DataType.FLOAT32)
        tensorImage.load(bitmap)
        return imageProcessor.process(tensorImage)
    }

    /**
     * Uploads [bitmap], mirrored vertically, into a temporary texture and blits it into [rect]
     * of [textureInfo]'s framebuffer.
     *
     * The temporary texture, framebuffer and mirrored bitmap copy are released even when a GL
     * call throws. [bitmap] itself is left untouched.
     */
    private fun copyBitmapToFbo(bitmap: Bitmap, textureInfo: GlTextureInfo, rect: GlRect) {
        val bitmapToGl = Matrix().apply { setScale(1f, -1f) }
        val flipped = Bitmap.createBitmap(bitmap, 0, 0, bitmap.width, bitmap.height, bitmapToGl, true)
        try {
            val texId = GlUtil.createTexture(bitmap.width, bitmap.height, false)
            try {
                val fboId = GlUtil.createFboForTexture(texId)
                try {
                    GlUtil.setTexture(texId, flipped)
                    GlUtil.blitFrameBuffer(
                        fboId,
                        GlRect(0, 0, bitmap.width, bitmap.height),
                        textureInfo.fboId,
                        rect,
                    )
                } finally {
                    GlUtil.deleteFbo(fboId)
                }
            } finally {
                GlUtil.deleteTexture(texId)
            }
        } finally {
            // Guard against createBitmap handing back the source instance.
            if (flipped !== bitmap) flipped.recycle()
        }
    }

    private companion object {
        const val PREDICT_MODEL = "predict_float16.tflite"
        const val TRANSFER_MODEL = "transfer_float16.tflite"
    }
}

0 commit comments

Comments
 (0)