1
+ /*
2
+ * Copyright 2024 The Android Open Source Project
3
+ *
4
+ * Licensed under the Apache License, Version 2.0 (the "License");
5
+ * you may not use this file except in compliance with the License.
6
+ * You may obtain a copy of the License at
7
+ *
8
+ * https://www.apache.org/licenses/LICENSE-2.0
9
+ *
10
+ * Unless required by applicable law or agreed to in writing, software
11
+ * distributed under the License is distributed on an "AS IS" BASIS,
12
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ * See the License for the specific language governing permissions and
14
+ * limitations under the License.
15
+ */
16
+
17
+ package com.example.platform.media.video
18
+
19
+ import android.content.Context
20
+ import android.graphics.Bitmap
21
+ import android.graphics.BitmapFactory
22
+ import android.graphics.Matrix
23
+ import androidx.media3.common.GlTextureInfo
24
+ import androidx.media3.common.VideoFrameProcessingException
25
+ import androidx.media3.common.util.GlRect
26
+ import androidx.media3.common.util.GlUtil
27
+ import androidx.media3.common.util.Size
28
+ import androidx.media3.common.util.UnstableApi
29
+ import androidx.media3.common.util.Util
30
+ import androidx.media3.effect.ByteBufferGlEffect
31
+ import com.google.common.collect.ImmutableMap
32
+ import com.google.common.util.concurrent.ListenableFuture
33
+ import com.google.common.util.concurrent.ListeningExecutorService
34
+ import com.google.common.util.concurrent.MoreExecutors
35
+ import org.tensorflow.lite.DataType
36
+ import org.tensorflow.lite.Interpreter
37
+ import org.tensorflow.lite.InterpreterApi
38
+ import org.tensorflow.lite.gpu.CompatibilityList
39
+ import org.tensorflow.lite.gpu.GpuDelegate
40
+ import org.tensorflow.lite.support.common.FileUtil
41
+ import org.tensorflow.lite.support.common.ops.DequantizeOp
42
+ import org.tensorflow.lite.support.common.ops.NormalizeOp
43
+ import org.tensorflow.lite.support.image.ImageProcessor
44
+ import org.tensorflow.lite.support.image.TensorImage
45
+ import org.tensorflow.lite.support.image.ops.ResizeOp
46
+ import org.tensorflow.lite.support.image.ops.ResizeWithCropOrPadOp
47
+ import org.tensorflow.lite.support.tensorbuffer.TensorBuffer
48
+ import java.util.concurrent.Future
49
+
50
/**
 * A [ByteBufferGlEffect.Processor] that runs a two-stage TensorFlow Lite style-transfer
 * pipeline on video frames.
 *
 * At construction time the "predict" model is run once on the style image loaded from the app
 * assets; its output is kept in memory. For every frame, [processImage] chains three
 * single-threaded executors: the frame is converted to a normalized [TensorImage], fed through
 * the "transfer" model together with the stored predict output, and the model output is
 * converted back into a [Bitmap]. [finishProcessingAndBlend] then blits that bitmap onto the
 * output frame.
 *
 * Call [release] when the effect is no longer needed; it frees the executors, the interpreter
 * and the GPU delegate (if one was created).
 *
 * NOTE(review): the interpreter is created on the constructing thread but invoked on the
 * `tfRun` executor thread; confirm this is acceptable for the GPU delegate backend in use.
 *
 * @param context used to open the model files and the style image from the app assets.
 * @param styleAssetFileName asset path of the style image.
 * @throws IllegalArgumentException if the style asset cannot be decoded as an image.
 */
@UnstableApi
class StyleTransferEffect(context: Context, styleAssetFileName: String) :
    ByteBufferGlEffect.Processor<Bitmap> {

    /** GPU delegate shared by the interpreters, or `null` when the device does not support it. */
    private val gpuDelegate: GpuDelegate?

    private val transformInterpreter: InterpreterApi
    private val inputTransformTargetHeight: Int
    private val inputTransformTargetWidth: Int
    private val outputTransformShape: IntArray

    // One single-threaded executor per pipeline stage, so each stage handles one frame at a time.
    private val preProcess: ListeningExecutorService =
        MoreExecutors.listeningDecorator(Util.newSingleThreadExecutor("preProcess"))
    private val postProcess: ListeningExecutorService =
        MoreExecutors.listeningDecorator(Util.newSingleThreadExecutor("postProcess"))
    private val tfRun: ListeningExecutorService =
        MoreExecutors.listeningDecorator(Util.newSingleThreadExecutor("tfRun"))

    /** Output of the predict model for the style image; second input of the transfer model. */
    private val predictOutput: TensorBuffer

    private var inputWidth: Int = 0
    private var inputHeight: Int = 0

    init {
        // Only create a GPU delegate when the device supports it; otherwise run on the CPU.
        val compatibilityList = CompatibilityList()
        val delegate: GpuDelegate? = try {
            if (compatibilityList.isDelegateSupportedOnThisDevice) {
                GpuDelegate(compatibilityList.bestOptionsForThisDevice)
            } else {
                null
            }
        } finally {
            compatibilityList.close()
        }
        gpuDelegate = delegate

        val options = Interpreter.Options()
        if (delegate != null) {
            options.addDelegate(delegate)
        }

        // The predict model is needed exactly once, so it is closed as soon as it has run.
        val predictInterpreter =
            Interpreter(FileUtil.loadMappedFile(context, PREDICT_MODEL), options)
        predictOutput = try {
            val inputPredictTargetHeight = predictInterpreter.getInputTensor(0).shape()[1]
            val inputPredictTargetWidth = predictInterpreter.getInputTensor(0).shape()[2]
            val outputPredictShape = predictInterpreter.getOutputTensor(0).shape()

            // `use` closes the asset stream even when decoding throws.
            val styleImage = context.assets.open(styleAssetFileName).use { stream ->
                BitmapFactory.decodeStream(stream)
            } ?: throw IllegalArgumentException(
                "Could not decode style image asset: $styleAssetFileName"
            )
            val styleTensorImage =
                getScaledTensorImage(styleImage, inputPredictTargetWidth, inputPredictTargetHeight)

            val styleOutput = TensorBuffer.createFixedSize(outputPredictShape, DataType.FLOAT32)
            predictInterpreter.run(styleTensorImage.buffer, styleOutput.buffer)
            styleOutput
        } finally {
            predictInterpreter.close()
        }

        transformInterpreter =
            InterpreterApi.create(FileUtil.loadMappedFile(context, TRANSFER_MODEL), options)
        inputTransformTargetHeight = transformInterpreter.getInputTensor(0).shape()[1]
        inputTransformTargetWidth = transformInterpreter.getInputTensor(0).shape()[2]
        outputTransformShape = transformInterpreter.getOutputTensor(0).shape()
    }

    /**
     * Records the input frame size and returns the size frames should be scaled to, which is the
     * transfer model's input width and height.
     */
    override fun configure(inputWidth: Int, inputHeight: Int): Size {
        this.inputWidth = inputWidth
        this.inputHeight = inputHeight
        return Size(inputTransformTargetWidth, inputTransformTargetHeight)
    }

    /**
     * Returns a square region anchored at (0, 0) whose side is the shorter side of the input
     * frame recorded by [configure].
     */
    override fun getScaledRegion(presentationTimeUs: Long): GlRect {
        val minSide = minOf(inputWidth, inputHeight)
        return GlRect(0, 0, minSide, minSide)
    }

    /**
     * Runs the pre-process, inference and post-process stages for [image] and returns a future
     * for the stylized bitmap.
     */
    override fun processImage(
        image: ByteBufferGlEffect.Image,
        presentationTimeUs: Long,
    ): ListenableFuture<Bitmap> {
        val tensorImageFuture = preProcess(image)
        val tensorBufferFuture = tfRun(tensorImageFuture)
        return postProcess(tensorBufferFuture)
    }

    /**
     * Shuts down the stage executors and closes the interpreter and GPU delegate.
     *
     * Work that was already submitted is allowed to finish. The native resources are closed on
     * the inference executor, after any queued inference, so they are not freed while in use.
     * Calling this more than once is a no-op.
     */
    override fun release() {
        if (tfRun.isShutdown) {
            return
        }
        preProcess.shutdown()
        tfRun.execute {
            // The interpreter is closed before the delegate it uses.
            transformInterpreter.close()
            gpuDelegate?.close()
        }
        tfRun.shutdown()
        postProcess.shutdown()
    }

    /**
     * Draws [result] into the region of [outputFrame] given by [getScaledRegion].
     *
     * @throws VideoFrameProcessingException if a GL call fails.
     */
    override fun finishProcessingAndBlend(
        outputFrame: GlTextureInfo,
        presentationTimeUs: Long,
        result: Bitmap,
    ) {
        try {
            copyBitmapToFbo(result, outputFrame, getScaledRegion(presentationTimeUs))
        } catch (e: GlUtil.GlException) {
            throw VideoFrameProcessingException.from(e)
        }
    }

    /** Copies [image] to a bitmap and converts it to the transfer model's input tensor. */
    private fun preProcess(image: ByteBufferGlEffect.Image): ListenableFuture<TensorImage> {
        return preProcess.submit<TensorImage> {
            val bitmap = image.copyToBitmap()
            getScaledTensorImage(bitmap, inputTransformTargetWidth, inputTransformTargetHeight)
        }
    }

    /**
     * Waits for the pre-processed frame, then runs the transfer model on it together with
     * [predictOutput]. Returns a future for the raw model output.
     */
    private fun tfRun(tensorImageFuture: Future<TensorImage>): ListenableFuture<TensorBuffer> {
        return tfRun.submit<TensorBuffer> {
            val tensorImage = tensorImageFuture.get()
            val outputImage = TensorBuffer.createFixedSize(outputTransformShape, DataType.FLOAT32)

            transformInterpreter.runForMultipleInputsOutputs(
                arrayOf(tensorImage.buffer, predictOutput.buffer),
                ImmutableMap.builder<Int, Any>().put(0, outputImage.buffer).build(),
            )
            outputImage
        }
    }

    /**
     * Waits for the model output, applies `DequantizeOp(0f, 255f)` to it and returns a future
     * for the resulting bitmap.
     */
    private fun postProcess(
        futureOutputImage: ListenableFuture<TensorBuffer>,
    ): ListenableFuture<Bitmap> {
        return postProcess.submit<Bitmap> {
            val outputImage = futureOutputImage.get()
            val imagePostProcessor = ImageProcessor.Builder()
                .add(DequantizeOp(0f, 255f))
                .build()
            val outputTensorImage = TensorImage(DataType.FLOAT32)
            outputTensorImage.load(outputImage)
            imagePostProcessor.process(outputTensorImage).bitmap
        }
    }

    /**
     * Crops or pads [bitmap] to a square of its shorter side, resizes it bilinearly to
     * [targetWidth] x [targetHeight] and applies `NormalizeOp(0f, 255f)`.
     */
    private fun getScaledTensorImage(bitmap: Bitmap, targetWidth: Int, targetHeight: Int): TensorImage {
        val cropSize = minOf(bitmap.width, bitmap.height)
        val imageProcessor = ImageProcessor.Builder()
            .add(ResizeWithCropOrPadOp(cropSize, cropSize))
            .add(ResizeOp(targetHeight, targetWidth, ResizeOp.ResizeMethod.BILINEAR))
            .add(NormalizeOp(0f, 255f))
            .build()
        val tensorImage = TensorImage(DataType.FLOAT32)
        tensorImage.load(bitmap)
        return imageProcessor.process(tensorImage)
    }

    /**
     * Uploads a vertically flipped copy of [bitmap] to a temporary texture and blits it into
     * [rect] of the framebuffer of [textureInfo].
     *
     * The temporary texture, framebuffer and flipped bitmap are released even when a GL call
     * throws.
     */
    private fun copyBitmapToFbo(bitmap: Bitmap, textureInfo: GlTextureInfo, rect: GlRect) {
        val bitmapToGl = Matrix().apply { setScale(1f, -1f) }
        val flipped =
            Bitmap.createBitmap(bitmap, 0, 0, bitmap.width, bitmap.height, bitmapToGl, true)
        try {
            val texId = GlUtil.createTexture(bitmap.width, bitmap.height, false)
            try {
                val fboId = GlUtil.createFboForTexture(texId)
                try {
                    GlUtil.setTexture(texId, flipped)
                    GlUtil.blitFrameBuffer(
                        fboId,
                        GlRect(0, 0, bitmap.width, bitmap.height),
                        textureInfo.fboId,
                        rect,
                    )
                } finally {
                    GlUtil.deleteFbo(fboId)
                }
            } finally {
                GlUtil.deleteTexture(texId)
            }
        } finally {
            // Never recycle the caller's bitmap; only the copy made here.
            if (flipped !== bitmap) {
                flipped.recycle()
            }
        }
    }

    private companion object {
        const val PREDICT_MODEL = "predict_float16.tflite"
        const val TRANSFER_MODEL = "transfer_float16.tflite"
    }
}
0 commit comments