Skip to content

Commit a6fe62c

Browse files
oauth2.Client now handles multipart/form-data properly
- added a test (with a reliance on the `mox` mocking library) to verify this - multipart requests put the authorization in the header, and do not touch the body.
1 parent e7427b6 commit a6fe62c

File tree

2 files changed

+58
-5
lines changed

2 files changed

+58
-5
lines changed

oauth2/__init__.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -584,11 +584,14 @@ def set_signature_method(self, method):
584584

585585
def request(self, uri, method="GET", body=None, headers=None,
586586
redirections=httplib2.DEFAULT_MAX_REDIRECTS, connection_type=None):
587-
587+
DEFAULT_CONTENT_TYPE = 'application/x-www-form-urlencoded'
588+
588589
if not isinstance(headers, dict):
589590
headers = {}
590591

591-
if body and method == "POST":
592+
is_multipart = method == 'POST' and headers.get('Content-Type', DEFAULT_CONTENT_TYPE) != DEFAULT_CONTENT_TYPE
593+
594+
if body and method == "POST" and not is_multipart:
592595
parameters = dict(parse_qsl(body))
593596
elif method == "GET":
594597
parsed = urlparse.urlparse(uri)
@@ -607,9 +610,13 @@ def request(self, uri, method="GET", body=None, headers=None,
607610

608611
req.sign_request(self.method, self.consumer, self.token)
609612

613+
610614
if method == "POST":
611-
body = req.to_postdata()
612-
headers['Content-Type'] = 'application/x-www-form-urlencoded'
615+
headers['Content-Type'] = headers.get('Content-Type', DEFAULT_CONTENT_TYPE)
616+
if is_multipart:
617+
headers.update(req.to_header())
618+
else:
619+
body = req.to_postdata()
613620
elif method == "GET":
614621
uri = req.to_url()
615622
else:

tests/test_oauth.py

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,8 @@
3131
import urllib
3232
import urlparse
3333
from types import ListType
34-
34+
import mox
35+
import httplib2
3536

3637
# Fix for python2.5 compatibility
3738
try:
@@ -737,6 +738,7 @@ class TestClient(unittest.TestCase):
737738
host = 'http://oauth-sandbox.sevengoslings.net'
738739

739740
def setUp(self):
741+
self.mox = mox.Mox()
740742
self.consumer = oauth.Consumer(key=self.consumer_key,
741743
secret=self.consumer_secret)
742744

@@ -747,13 +749,31 @@ def setUp(self):
747749
'blah': 599999
748750
}
749751

752+
def tearDown(self):
753+
self.mox.UnsetStubs()
754+
750755
def _uri(self, type):
751756
uri = self.oauth_uris.get(type)
752757
if uri is None:
753758
raise KeyError("%s is not a valid OAuth URI type." % type)
754759

755760
return "%s%s" % (self.host, uri)
756761

762+
def create_simple_multipart_data(self, data):
763+
boundary = '---Boundary-%d' % random.randint(1,1000)
764+
crlf = '\r\n'
765+
items = []
766+
for key, value in data.iteritems():
767+
items += [
768+
'--'+boundary,
769+
'Content-Disposition: form-data; name="%s"'%str(key),
770+
'',
771+
str(value),
772+
]
773+
items += ['', '--'+boundary+'--', '']
774+
content_type = 'multipart/form-data; boundary=%s' % boundary
775+
return content_type, crlf.join(items)
776+
757777
def test_access_token_get(self):
758778
"""Test getting an access token via GET."""
759779
client = oauth.Client(self.consumer, None)
@@ -789,6 +809,32 @@ def test_two_legged_get(self):
789809
resp, content = self._two_legged("GET")
790810
self.assertEquals(int(resp['status']), 200)
791811

812+
def test_multipart_post_does_not_alter_body(self):
813+
self.mox.StubOutWithMock(httplib2.Http, 'request')
814+
random_result = random.randint(1,100)
815+
816+
data = {
817+
'rand-%d'%random.randint(1,100):random.randint(1,100),
818+
}
819+
content_type, body = self.create_simple_multipart_data(data)
820+
821+
client = oauth.Client(self.consumer, None)
822+
uri = self._uri('two_legged')
823+
824+
expected_kwargs = {
825+
'method':'POST',
826+
'body':body,
827+
'redirections':httplib2.DEFAULT_MAX_REDIRECTS,
828+
'connection_type':None,
829+
'headers':mox.IsA(dict),
830+
}
831+
httplib2.Http.request(client, uri, **expected_kwargs).AndReturn(random_result)
832+
833+
self.mox.ReplayAll()
834+
result = client.request(uri, 'POST', headers={'Content-Type':content_type}, body=body)
835+
self.assertEqual(result, random_result)
836+
self.mox.VerifyAll()
837+
792838
if __name__ == "__main__":
793839
unittest.main()
794840

0 commit comments

Comments
 (0)