22// Licensed under the MIT License.
33
44use crate :: http:: {
5+ headers:: { HeaderValue , CONTENT_LENGTH } ,
56 options:: TransportOptions ,
67 policies:: { Policy , PolicyResult } ,
7- Context , Request ,
8+ Context , Header , Request ,
89} ;
910use async_trait:: async_trait;
11+ use http_types:: Method ;
1012use std:: sync:: Arc ;
1113use tracing:: debug;
1214
@@ -33,9 +35,111 @@ impl Policy for TransportPolicy {
3335 // there must be no more policies
3436 assert_eq ! ( 0 , next. len( ) ) ;
3537
38+ if request. body ( ) . is_empty ( )
39+ && matches ! (
40+ * request. method( ) ,
41+ Method :: Patch | Method :: Post | Method :: Put
42+ )
43+ {
44+ request. add_mandatory_header ( EMPTY_CONTENT_LENGTH ) ;
45+ }
46+
3647 debug ! ( ?request, "sending request '{}'" , request. url) ;
3748 let response = { self . transport_options . send ( ctx, request) } ;
3849
3950 response. await
4051 }
4152}
53+
54+ const EMPTY_CONTENT_LENGTH : & EmptyContentLength = & EmptyContentLength ;
55+
56+ struct EmptyContentLength ;
57+
58+ impl Header for EmptyContentLength {
59+ fn name ( & self ) -> crate :: http:: headers:: HeaderName {
60+ CONTENT_LENGTH
61+ }
62+ fn value ( & self ) -> crate :: http:: headers:: HeaderValue {
63+ HeaderValue :: from ( "0" )
64+ }
65+ }
66+
67+ #[ cfg( all( test, not( target_family = "wasm" ) ) ) ]
68+ mod tests {
69+ use super :: * ;
70+ use crate :: http:: { headers:: Headers , Response } ;
71+ use http_types:: StatusCode ;
72+
73+ #[ derive( Debug ) ]
74+ struct MockTransport ;
75+
76+ #[ async_trait]
77+ impl Policy for MockTransport {
78+ async fn send (
79+ & self ,
80+ _ctx : & Context ,
81+ _request : & mut Request ,
82+ _next : & [ Arc < dyn Policy > ] ,
83+ ) -> PolicyResult {
84+ PolicyResult :: Ok ( Response :: from_bytes (
85+ StatusCode :: Ok ,
86+ Headers :: new ( ) ,
87+ Vec :: new ( ) ,
88+ ) )
89+ }
90+ }
91+
92+ #[ tokio:: test]
93+ async fn test_content_length ( ) -> std:: result:: Result < ( ) , Box < dyn std:: error:: Error > > {
94+ let transport =
95+ TransportPolicy :: new ( TransportOptions :: new_custom_policy ( Arc :: new ( MockTransport ) ) ) ;
96+
97+ let mut request = Request :: new ( "http://localhost" . parse ( ) ?, Method :: Get ) ;
98+ transport. send ( & Context :: new ( ) , & mut request, & [ ] ) . await ?;
99+ assert ! ( !request. headers( ) . iter( ) . any( |h| CONTENT_LENGTH . eq( h. 0 ) ) ) ;
100+
101+ request. headers = Headers :: new ( ) ;
102+ request. method = Method :: Patch ;
103+ transport. send ( & Context :: new ( ) , & mut request, & [ ] ) . await ?;
104+ assert_eq ! (
105+ request
106+ . headers( )
107+ . get_with( & CONTENT_LENGTH , |v| v. as_str( ) . parse:: <u16 >( ) )
108+ . unwrap( ) ,
109+ 0u16
110+ ) ;
111+
112+ request. headers = Headers :: new ( ) ;
113+ request. method = Method :: Post ;
114+ transport. send ( & Context :: new ( ) , & mut request, & [ ] ) . await ?;
115+ assert_eq ! (
116+ request
117+ . headers( )
118+ . get_with( & CONTENT_LENGTH , |v| v. as_str( ) . parse:: <u16 >( ) )
119+ . unwrap( ) ,
120+ 0u16
121+ ) ;
122+
123+ request. headers = Headers :: new ( ) ;
124+ request. method = Method :: Put ;
125+ transport. send ( & Context :: new ( ) , & mut request, & [ ] ) . await ?;
126+ assert_eq ! (
127+ request
128+ . headers( )
129+ . get_with( & CONTENT_LENGTH , |v| v. as_str( ) . parse:: <u16 >( ) )
130+ . unwrap( ) ,
131+ 0u16
132+ ) ;
133+
134+ // The HttpClient would add this normally.
135+ request. headers = Headers :: new ( ) ;
136+ request. body = "{}" . into ( ) ;
137+ transport. send ( & Context :: new ( ) , & mut request, & [ ] ) . await ?;
138+ request
139+ . headers ( )
140+ . get_with ( & CONTENT_LENGTH , |v| v. as_str ( ) . parse :: < u16 > ( ) )
141+ . expect_err ( "expected no content-length header" ) ;
142+
143+ Ok ( ( ) )
144+ }
145+ }
0 commit comments