2222import io .apitally .common .dto .Response ;
2323import jakarta .servlet .FilterChain ;
2424import jakarta .servlet .ServletException ;
25+ import jakarta .servlet .ServletOutputStream ;
26+ import jakarta .servlet .WriteListener ;
2527import jakarta .servlet .http .HttpServletRequest ;
2628import jakarta .servlet .http .HttpServletResponse ;
29+ import jakarta .servlet .http .HttpServletResponseWrapper ;
2730import jakarta .validation .ConstraintViolation ;
2831import jakarta .validation .ConstraintViolationException ;
2932
@@ -58,14 +61,17 @@ protected void doFilterInternal(@NonNull HttpServletRequest request, @NonNull Ht
5861 ContentCachingResponseWrapper cachingResponse = shouldCacheResponse
5962 ? new ContentCachingResponseWrapper (response )
6063 : null ;
64+ CountingResponseWrapper countingResponse = cachingResponse == null
65+ ? new CountingResponseWrapper (response )
66+ : null ;
6167
6268 Exception exception = null ;
6369 final long startTime = System .currentTimeMillis ();
6470
6571 try {
6672 filterChain .doFilter (
6773 cachingRequest != null ? cachingRequest : request ,
68- cachingResponse != null ? cachingResponse : response );
74+ cachingResponse != null ? cachingResponse : countingResponse );
6975 } catch (Exception e ) {
7076 exception = e ;
7177 throw e ;
@@ -92,9 +98,11 @@ protected void doFilterInternal(@NonNull HttpServletRequest request, @NonNull Ht
9298 final long requestSize = requestContentLength >= 0 ? requestContentLength
9399 : cachingRequest != null ? requestBody .length : -1 ;
94100 final long responseContentLength = getResponseContentLength (response );
95- final long responseSize = responseContentLength >= 0
96- ? responseContentLength
97- : (cachingResponse != null ? responseBody .length : -1 );
101+ final long responseSize = responseContentLength >= 0 ? responseContentLength
102+ : (cachingResponse != null ? responseBody .length
103+ : countingResponse != null
104+ ? countingResponse .getByteCount ()
105+ : -1 );
98106 client .requestCounter
99107 .addRequest (consumerIdentifier , request .getMethod (), path , response .getStatus (),
100108 responseTimeInMillis ,
@@ -154,7 +162,6 @@ protected void doFilterInternal(@NonNull HttpServletRequest request, @NonNull Ht
154162 logger .error ("Error in Apitally filter" , e );
155163 }
156164 }
157-
158165 }
159166
160167 private static long getResponseContentLength (HttpServletResponse response ) {
@@ -167,4 +174,76 @@ private static long getResponseContentLength(HttpServletResponse response) {
167174 }
168175 return -1L ;
169176 }
177+
178+ private static class CountingResponseWrapper extends HttpServletResponseWrapper {
179+ private CountingServletOutputStream countingStream ;
180+
181+ public CountingResponseWrapper (HttpServletResponse response ) {
182+ super (response );
183+ }
184+
185+ @ Override
186+ public ServletOutputStream getOutputStream () throws IOException {
187+ if (countingStream == null ) {
188+ countingStream = new CountingServletOutputStream (super .getOutputStream ());
189+ }
190+ return countingStream ;
191+ }
192+
193+ public long getByteCount () {
194+ return countingStream != null ? countingStream .getByteCount () : 0 ;
195+ }
196+ }
197+
198+ private static class CountingServletOutputStream extends ServletOutputStream {
199+ private final ServletOutputStream outputStream ;
200+ private long byteCount ;
201+
202+ public CountingServletOutputStream (ServletOutputStream outputStream ) {
203+ this .outputStream = outputStream ;
204+ this .byteCount = 0 ;
205+ }
206+
207+ @ Override
208+ public boolean isReady () {
209+ return outputStream .isReady ();
210+ }
211+
212+ @ Override
213+ public void setWriteListener (WriteListener writeListener ) {
214+ outputStream .setWriteListener (writeListener );
215+ }
216+
217+ @ Override
218+ public void write (int b ) throws IOException {
219+ outputStream .write (b );
220+ byteCount ++;
221+ }
222+
223+ @ Override
224+ public void write (byte [] b ) throws IOException {
225+ outputStream .write (b );
226+ byteCount += b .length ;
227+ }
228+
229+ @ Override
230+ public void write (byte [] b , int off , int len ) throws IOException {
231+ outputStream .write (b , off , len );
232+ byteCount += len ;
233+ }
234+
235+ @ Override
236+ public void flush () throws IOException {
237+ outputStream .flush ();
238+ }
239+
240+ @ Override
241+ public void close () throws IOException {
242+ outputStream .close ();
243+ }
244+
245+ public long getByteCount () {
246+ return byteCount ;
247+ }
248+ }
170249}
0 commit comments