6 changes: 6 additions & 0 deletions app/src/main/AndroidManifest.xml
@@ -34,6 +34,12 @@
                <category android:name="android.intent.category.LAUNCHER" />
            </intent-filter>
        </activity>

        <!-- Required for the TFLite/LiteRT style transfer demo -->
        <uses-library android:name="libOpenCL.so"
            android:required="false"/>
        <uses-library android:name="libOpenCL-pixel.so"
            android:required="false"/>
    </application>

</manifest>
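Both `uses-library` entries are declared with `android:required="false"` so the app still installs on devices that do not ship these vendor OpenCL libraries; the TFLite/LiteRT GPU delegate simply cannot use them there. A minimal runtime-fallback sketch under that assumption (the `buildInterpreterOptions` helper is illustrative and not part of this change; it reuses the same `CompatibilityList`/`GpuDelegate` APIs the sample already depends on):

```kotlin
import org.tensorflow.lite.Interpreter
import org.tensorflow.lite.gpu.CompatibilityList
import org.tensorflow.lite.gpu.GpuDelegate

// Hypothetical helper: prefer the GPU delegate (which loads the optional OpenCL
// libraries declared in the manifest) and fall back to CPU threads otherwise.
fun buildInterpreterOptions(): Interpreter.Options {
    val options = Interpreter.Options()
    val compatibilityList = CompatibilityList()
    if (compatibilityList.isDelegateSupportedOnThisDevice) {
        options.addDelegate(GpuDelegate(compatibilityList.bestOptionsForThisDevice))
    } else {
        // No GPU delegate support on this device: run the interpreter on CPU threads.
        options.setNumThreads(4)
    }
    return options
}
```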
2 changes: 1 addition & 1 deletion gradle/libs.versions.toml
@@ -39,7 +39,7 @@ androidxTestExtTruth = "1.5.0"
androidxTestRules = "1.5.0"
androidxTestRunner = "1.5.2"
androidxUiAutomator = "2.2.0"
media3 = "1.4.0-rc01"
media3 = "1.5.0"
appcompat = "1.6.1"
material = "1.12.0-beta01"
constraintlayout = "2.1.4"
2 changes: 1 addition & 1 deletion samples/README.md
@@ -117,7 +117,7 @@ A sample showcasing how to handle calls with the Jetpack Telecom API
- [TextSpan](user-interface/text/src/main/java/com/example/platform/ui/text/TextSpan.kt):
buildSpannedString is useful for quickly building rich text.
- [Transformer and TFLite](media/video/src/main/java/com/example/platform/media/video/TransformerTFLite.kt):
This sample demonstrates using Transformer with TFLite by applying a selected art style to a video.
This sample demonstrates using Transformer with TFLite/LiteRT by applying a selected art style to a video.
- [UltraHDR Image Capture](camera/camera2/src/main/java/com/example/platform/camera/imagecapture/Camera2UltraHDRCapture.kt):
This sample demonstrates how to capture a 10-bit compressed still image and
- [UltraHDR to HDR Video](media/ultrahdr/src/main/java/com/example/platform/media/ultrahdr/video/UltraHDRToHDRVideo.kt):
29 changes: 26 additions & 3 deletions samples/media/video/download_model.gradle
@@ -13,16 +13,39 @@
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
tasks.register('downloadModelFile1', Download) {
task downloadModelFile(type: Download) {
    src 'https://storage.googleapis.com/download.tensorflow.org/models/tflite/task_library/style_transfer/android/magenta_arbitrary-image-stylization-v1-256_int8_prediction_1.tflite'
    dest project.ext.ASSET_DIR + '/predict_int8.tflite'
    overwrite false
}

task downloadModelFile0(type: Download) {
    src 'https://storage.googleapis.com/download.tensorflow.org/models/tflite/task_library/style_transfer/android/magenta_arbitrary-image-stylization-v1-256_int8_transfer_1.tflite'
    dest project.ext.ASSET_DIR + '/transfer_int8.tflite'
    overwrite false
}

task downloadModelFile1(type: Download) {
    src 'https://storage.googleapis.com/download.tensorflow.org/models/tflite/task_library/style_transfer/android/magenta_arbitrary-image-stylization-v1-256_fp16_prediction_1.tflite'
    dest project.ext.ASSET_DIR + '/predict_float16.tflite'
    overwrite false
}

tasks.register('downloadModelFile2', Download) {
task downloadModelFile2(type: Download) {
    src 'https://storage.googleapis.com/download.tensorflow.org/models/tflite/task_library/style_transfer/android/magenta_arbitrary-image-stylization-v1-256_fp16_transfer_1.tflite'
    dest project.ext.ASSET_DIR + '/transfer_float16.tflite'
    overwrite false
}

preBuild.dependsOn downloadModelFile1, downloadModelFile2
task copyTestModel(type: Copy, dependsOn: downloadModelFile1) {
    from project.ext.ASSET_DIR + '/predict_float16.tflite'
    into project.ext.TEST_ASSETS_DIR
}

task copyTestModel0(type: Copy, dependsOn: downloadModelFile2) {
    from project.ext.ASSET_DIR + '/transfer_float16.tflite'
    into project.ext.TEST_ASSETS_DIR
}

preBuild.dependsOn downloadModelFile, downloadModelFile1, downloadModelFile2,
        copyTestModel, copyTestModel0
185 changes: 185 additions & 0 deletions samples/media/video/src/main/java/com/example/platform/media/video/StyleTransferEffect.kt
@@ -0,0 +1,185 @@
/*
 * Copyright 2024 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.example.platform.media.video

import android.content.Context
import android.graphics.Bitmap
import android.graphics.BitmapFactory
import android.graphics.Matrix
import androidx.media3.common.GlTextureInfo
import androidx.media3.common.VideoFrameProcessingException
import androidx.media3.common.util.GlRect
import androidx.media3.common.util.GlUtil
import androidx.media3.common.util.Size
import androidx.media3.common.util.UnstableApi
import androidx.media3.common.util.Util
import androidx.media3.effect.ByteBufferGlEffect
import com.google.common.collect.ImmutableMap
import com.google.common.util.concurrent.ListenableFuture
import com.google.common.util.concurrent.ListeningExecutorService
import com.google.common.util.concurrent.MoreExecutors
import org.tensorflow.lite.DataType
import org.tensorflow.lite.Interpreter
import org.tensorflow.lite.InterpreterApi
import org.tensorflow.lite.gpu.CompatibilityList
import org.tensorflow.lite.gpu.GpuDelegate
import org.tensorflow.lite.support.common.FileUtil
import org.tensorflow.lite.support.common.ops.DequantizeOp
import org.tensorflow.lite.support.common.ops.NormalizeOp
import org.tensorflow.lite.support.image.ImageProcessor
import org.tensorflow.lite.support.image.TensorImage
import org.tensorflow.lite.support.image.ops.ResizeOp
import org.tensorflow.lite.support.image.ops.ResizeWithCropOrPadOp
import org.tensorflow.lite.support.tensorbuffer.TensorBuffer
import java.util.concurrent.Future

@UnstableApi
class StyleTransferEffect(context: Context, styleAssetFileName: String) : ByteBufferGlEffect.Processor<Bitmap> {

    private val transformInterpreter: InterpreterApi
    private val inputTransformTargetHeight: Int
    private val inputTransformTargetWidth: Int
    private val outputTransformShape: IntArray

    private var preProcess: ListeningExecutorService = MoreExecutors.listeningDecorator(
        Util.newSingleThreadExecutor("preProcess"))
    private var postProcess: ListeningExecutorService = MoreExecutors.listeningDecorator(
        Util.newSingleThreadExecutor("postProcess"))
    private var tfRun: ListeningExecutorService = MoreExecutors.listeningDecorator(
        Util.newSingleThreadExecutor("tfRun"))

    private val predictOutput: TensorBuffer

    private var inputWidth: Int = 0
    private var inputHeight: Int = 0

    init {
        // Run the style prediction model once on the style image to compute the style
        // bottleneck (predictOutput); the transfer model is then run on every video frame.
        val options = Interpreter.Options()
        val compatibilityList = CompatibilityList()
        val gpuDelegateOptions = compatibilityList.bestOptionsForThisDevice
        val gpuDelegate = GpuDelegate(gpuDelegateOptions)
        options.addDelegate(gpuDelegate)
        val predictModel = "predict_float16.tflite"
        val transferModel = "transfer_float16.tflite"
        val predictInterpreter = Interpreter(FileUtil.loadMappedFile(context, predictModel), options)
        transformInterpreter = InterpreterApi.create(FileUtil.loadMappedFile(context, transferModel), options)
        val inputPredictTargetHeight = predictInterpreter.getInputTensor(0).shape()[1]
        val inputPredictTargetWidth = predictInterpreter.getInputTensor(0).shape()[2]
        val outputPredictShape = predictInterpreter.getOutputTensor(0).shape()

        inputTransformTargetHeight = transformInterpreter.getInputTensor(0).shape()[1]
        inputTransformTargetWidth = transformInterpreter.getInputTensor(0).shape()[2]
        outputTransformShape = transformInterpreter.getOutputTensor(0).shape()

        val inputStream = context.assets.open(styleAssetFileName)
        val styleImage = BitmapFactory.decodeStream(inputStream)
        inputStream.close()
        val styleTensorImage = getScaledTensorImage(styleImage, inputPredictTargetWidth, inputPredictTargetHeight)
        predictOutput = TensorBuffer.createFixedSize(outputPredictShape, DataType.FLOAT32)
        predictInterpreter.run(styleTensorImage.buffer, predictOutput.buffer)
    }

    override fun configure(inputWidth: Int, inputHeight: Int): Size {
        this.inputWidth = inputWidth
        this.inputHeight = inputHeight
        return Size(inputTransformTargetWidth, inputTransformTargetHeight)
    }

    override fun getScaledRegion(presentationTimeUs: Long): GlRect {
        // Read a square region whose side is the frame's shorter dimension.
        val minSide = minOf(inputWidth, inputHeight)
        return GlRect(0, 0, minSide, minSide)
    }

    override fun processImage(
        image: ByteBufferGlEffect.Image,
        presentationTimeUs: Long,
    ): ListenableFuture<Bitmap> {
        // Chain the three stages on their own executors: bitmap pre-processing,
        // TFLite inference, then post-processing back into a Bitmap.
        val tensorImageFuture = preProcess(image)
        val tensorBufferFuture = tfRun(tensorImageFuture)
        return postProcess(tensorBufferFuture)
    }

    override fun release() {}

    override fun finishProcessingAndBlend(
        outputFrame: GlTextureInfo,
        presentationTimeUs: Long,
        result: Bitmap,
    ) {
        // Draw the stylized bitmap back onto the output frame's framebuffer.
        try {
            copyBitmapToFbo(result, outputFrame, getScaledRegion(presentationTimeUs))
        } catch (e: GlUtil.GlException) {
            throw VideoFrameProcessingException.from(e)
        }
    }

    private fun preProcess(image: ByteBufferGlEffect.Image): ListenableFuture<TensorImage> {
        return preProcess.submit<TensorImage> {
            val bitmap = image.copyToBitmap()
            getScaledTensorImage(bitmap, inputTransformTargetWidth, inputTransformTargetHeight)
        }
    }

    private fun tfRun(tensorImageFuture: Future<TensorImage>): ListenableFuture<TensorBuffer> {
        return tfRun.submit<TensorBuffer> {
            val tensorImage = tensorImageFuture.get()
            val outputImage = TensorBuffer.createFixedSize(outputTransformShape, DataType.FLOAT32)

            // The transfer model takes the frame and the precomputed style bottleneck as inputs.
            transformInterpreter.runForMultipleInputsOutputs(
                arrayOf(tensorImage.buffer, predictOutput.buffer),
                ImmutableMap.builder<Int, Any>().put(0, outputImage.buffer).build()
            )
            outputImage
        }
    }

    private fun postProcess(futureOutputImage: ListenableFuture<TensorBuffer>): ListenableFuture<Bitmap> {
        return postProcess.submit<Bitmap> {
            val outputImage = futureOutputImage.get()
            // Map the model's [0, 1] float output back to [0, 255] pixel values.
            val imagePostProcessor = ImageProcessor.Builder()
                .add(DequantizeOp(0f, 255f))
                .build()
            val outputTensorImage = TensorImage(DataType.FLOAT32)
            outputTensorImage.load(outputImage)
            imagePostProcessor.process(outputTensorImage).bitmap
        }
    }

    private fun getScaledTensorImage(bitmap: Bitmap, targetWidth: Int, targetHeight: Int): TensorImage {
        // Center-crop to a square, resize to the model's input size, and normalize to [0, 1].
        val cropSize = minOf(bitmap.width, bitmap.height)
        val imageProcessor = ImageProcessor.Builder()
            .add(ResizeWithCropOrPadOp(cropSize, cropSize))
            .add(ResizeOp(targetHeight, targetWidth, ResizeOp.ResizeMethod.BILINEAR))
            .add(NormalizeOp(0f, 255f))
            .build()
        val tensorImage = TensorImage(DataType.FLOAT32)
        tensorImage.load(bitmap)
        return imageProcessor.process(tensorImage)
    }

    private fun copyBitmapToFbo(bitmap: Bitmap, textureInfo: GlTextureInfo, rect: GlRect) {
        // Flip vertically to convert from Bitmap to OpenGL texture orientation before
        // blitting the stylized frame into the output framebuffer.
        val bitmapToGl = Matrix().apply { setScale(1f, -1f) }
        val texId = GlUtil.createTexture(bitmap.width, bitmap.height, false)
        val fboId = GlUtil.createFboForTexture(texId)
        GlUtil.setTexture(texId,
            Bitmap.createBitmap(bitmap, 0, 0, bitmap.width, bitmap.height, bitmapToGl, true))
        GlUtil.blitFrameBuffer(fboId, GlRect(0, 0, bitmap.width, bitmap.height), textureInfo.fboId, rect)
        GlUtil.deleteTexture(texId)
        GlUtil.deleteFbo(fboId)
    }
}
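For context, a minimal sketch of how a `ByteBufferGlEffect.Processor` like `StyleTransferEffect` might be wired into a Transformer export with media3 1.5.0. This is an illustration rather than the sample's actual wiring (which lives in TransformerTFLite.kt and is not shown in this diff); the function name, style asset, URIs, and output path are placeholders.

```kotlin
import android.content.Context
import android.net.Uri
import androidx.media3.common.MediaItem
import androidx.media3.common.util.UnstableApi
import androidx.media3.effect.ByteBufferGlEffect
import androidx.media3.transformer.Composition
import androidx.media3.transformer.EditedMediaItem
import androidx.media3.transformer.Effects
import androidx.media3.transformer.ExportException
import androidx.media3.transformer.ExportResult
import androidx.media3.transformer.Transformer

// Sketch only: wrap the processor in a ByteBufferGlEffect and export with Transformer.
// The style asset name, input URI, and output path are placeholders, not values from this PR.
@UnstableApi
fun exportWithStyleTransfer(context: Context, inputVideoUri: Uri, outputFilePath: String) {
    val styleEffect = ByteBufferGlEffect(StyleTransferEffect(context, "style.jpg"))

    val editedMediaItem = EditedMediaItem.Builder(MediaItem.fromUri(inputVideoUri))
        .setEffects(Effects(/* audioProcessors= */ emptyList(), /* videoEffects= */ listOf(styleEffect)))
        .build()

    val transformer = Transformer.Builder(context)
        .addListener(object : Transformer.Listener {
            override fun onCompleted(composition: Composition, exportResult: ExportResult) {
                // Export finished; the stylized video is at outputFilePath.
            }

            override fun onError(
                composition: Composition,
                exportResult: ExportResult,
                exportException: ExportException,
            ) {
                // Handle the failure, e.g. surface exportException to the user.
            }
        })
        .build()

    // Transformer must be used from a thread with a Looper (typically the main thread).
    transformer.start(editedMediaItem, outputFilePath)
}
```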