Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
184 changes: 129 additions & 55 deletions oauth2/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,16 @@
import urllib
import time
import random
import urlparse
import hmac
import binascii
import httplib2

try:
import urlparse
except ImportError:
# urlparse location changed in python 3
from urllib import parse as urlparse

try:
from urlparse import parse_qs
parse_qs # placate pyflakes
Expand All @@ -39,8 +44,7 @@
from cgi import parse_qs

try:
from hashlib import sha1
sha = sha1
from hashlib import sha1 as sha
except ImportError:
# hashlib was added in Python 2.5
import sha
Expand All @@ -49,7 +53,7 @@

__version__ = _version.__version__

OAUTH_VERSION = '1.0' # Hi Blaine!
OAUTH_VERSION = '1.0'
HTTP_METHOD = 'GET'
SIGNATURE_METHOD = 'PLAINTEXT'

Expand Down Expand Up @@ -87,7 +91,7 @@ def build_xoauth_string(url, consumer, token=None):
request.sign_request(signing_method, consumer, token)

params = []
for k, v in sorted(request.iteritems()):
for k, v in sorted(request.items()):
if v is not None:
params.append('%s="%s"' % (k, escape(v)))

Expand All @@ -97,41 +101,66 @@ def build_xoauth_string(url, consumer, token=None):
def to_unicode(s):
""" Convert to unicode, raise exception with instructive error
message if s is not unicode, ascii, or utf-8. """
if not isinstance(s, unicode):
# Python 3 strings are unicode (utf-8) by default
try:
if not isinstance(s, unicode):
if not isinstance(s, str):
raise TypeError('You are required to pass either unicode or string here, not: %r (%s)' % (type(s), s))
try:
s = s.decode('utf-8')
except UnicodeDecodeError as le:
raise TypeError('You are required to pass either a unicode object or a utf-8 string here. You passed a Python string object which contained non-utf-8: %r. The UnicodeDecodeError that resulted from attempting to interpret it as utf-8 was: %s' % (s, le,))
except NameError:
if not isinstance(s, str):
raise TypeError('You are required to pass either unicode or string here, not: %r (%s)' % (type(s), s))
try:
s = s.decode('utf-8')
except UnicodeDecodeError, le:
s = s.encode('utf-8')
except UnicodeDecodeError as le:
raise TypeError('You are required to pass either a unicode object or a utf-8 string here. You passed a Python string object which contained non-utf-8: %r. The UnicodeDecodeError that resulted from attempting to interpret it as utf-8 was: %s' % (s, le,))
return s

def to_utf8(s):
return to_unicode(s).encode('utf-8')

def to_unicode_if_string(s):
if isinstance(s, basestring):
return to_unicode(s)
else:
return s
try:
if isinstance(s, basestring):
return to_unicode(s)
else:
return s
except NameError:
if isinstance(s, str):
return to_unicode(s)
else:
return s

def to_utf8_if_string(s):
if isinstance(s, basestring):
return to_utf8(s)
else:
return s
try:
if isinstance(s, basestring):
return to_utf8(s)
else:
return s
except NameError:
if isinstance(s, str):
return to_utf8(s)
else:
return s

def to_unicode_optional_iterator(x):
"""
Raise TypeError if x is a str containing non-utf8 bytes or if x is
an iterable which contains such a str.
"""
if isinstance(x, basestring):
return to_unicode(x)
try:
if isinstance(x, basestring):
return to_unicode(x)
except NameError:
if isinstance(x, str):
return to_unicode(x)

try:
l = list(x)
except TypeError, e:
except TypeError as e:
assert 'is not iterable' in str(e)
return x
else:
Expand All @@ -142,20 +171,27 @@ def to_utf8_optional_iterator(x):
Raise TypeError if x is a str or if x is an iterable which
contains a str.
"""
if isinstance(x, basestring):
return to_utf8(x)
try:
if isinstance(x, basestring):
return to_utf8(x)
except NameError:
if isinstance(x, str):
return to_utf8(x)

try:
l = list(x)
except TypeError, e:
except TypeError as e:
assert 'is not iterable' in str(e)
return x
else:
return [ to_utf8_if_string(e) for e in l ]

def escape(s):
"""Escape a URL including any /."""
return urllib.quote(s.encode('utf-8'), safe='~')
try:
return urllib.quote(s.encode('utf-8'), safe='~')
except AttributeError:
return urlparse.quote(s.encode('utf-8'), safe='~')

def generate_timestamp():
"""Get seconds since epoch (UTC)."""
Expand Down Expand Up @@ -205,8 +241,10 @@ def __init__(self, key, secret):
def __str__(self):
data = {'oauth_consumer_key': self.key,
'oauth_consumer_secret': self.secret}

return urllib.urlencode(data)
try:
return urllib.urlencode(data)
except AttributeError:
return urlparse.urlencode(data)


class Token(object):
Expand Down Expand Up @@ -274,7 +312,10 @@ def to_string(self):

if self.callback_confirmed is not None:
data['oauth_callback_confirmed'] = self.callback_confirmed
return urllib.urlencode(data)
try:
return urllib.urlencode(data)
except AttributeError:
return urlparse.urlencode(data)

@staticmethod
def from_string(s):
Expand Down Expand Up @@ -345,7 +386,7 @@ def __init__(self, method=HTTP_METHOD, url=None, parameters=None,
self.url = to_unicode(url)
self.method = method
if parameters is not None:
for k, v in parameters.iteritems():
for k, v in parameters.items():
k = to_unicode(k)
v = to_unicode_optional_iterator(v)
self[k] = v
Expand Down Expand Up @@ -382,7 +423,7 @@ def _get_timestamp_nonce(self):

def get_nonoauth_parameters(self):
"""Get any non-OAuth parameters."""
return dict([(k, v) for k, v in self.iteritems()
return dict([(k, v) for k, v in self.items()
if not k.startswith('oauth_')])

def to_header(self, realm=''):
Expand All @@ -402,13 +443,16 @@ def to_header(self, realm=''):
def to_postdata(self):
"""Serialize as post data for a POST request."""
d = {}
for k, v in self.iteritems():
for k, v in self.items():
d[k.encode('utf-8')] = to_utf8_optional_iterator(v)

# tell urlencode to deal with sequence values and map them correctly
# to resulting querystring. for example self["k"] = ["v1", "v2"] will
# result in 'k=v1&k=v2' and not k=%5B%27v1%27%2C+%27v2%27%5D
return urllib.urlencode(d, True).replace('+', '%20')
try:
return urllib.urlencode(d, True).replace('+', '%20')
except AttributeError:
return urlparse.urlencode(d, True).replace('+', '%20')

def to_url(self):
"""Serialize as a URL for a GET request."""
Expand All @@ -430,15 +474,20 @@ def to_url(self):
fragment = to_utf8(base_url.fragment)
except AttributeError:
# must be python <2.5
scheme = to_utf8(base_url[0])
netloc = to_utf8(base_url[1])
path = to_utf8(base_url[2])
params = to_utf8(base_url[3])
fragment = to_utf8(base_url[5])

url = (scheme, netloc, path, params,
urllib.urlencode(query, True), fragment)
return urlparse.urlunparse(url)
scheme = base_url[0]
netloc = base_url[1]
path = base_url[2]
params = base_url[3]
fragment = base_url[5]

try:
url = (scheme, netloc, path, params,
urllib.urlencode(query, True), fragment)
return urllib.urlunparse(url)
except AttributeError:
url = (scheme, netloc, path, params,
urlparse.urlencode(query, True), fragment)
return urlparse.urlunparse(url)

def get_parameter(self, parameter):
ret = self.get(parameter)
Expand All @@ -450,21 +499,33 @@ def get_parameter(self, parameter):
def get_normalized_parameters(self):
"""Return a string that contains the parameters that must be signed."""
items = []
for key, value in self.iteritems():
for key, value in self.items():
if key == 'oauth_signature':
continue
# 1.0a/9.1.1 states that kvp must be sorted by key, then by value,
# so we unpack sequence values into multiple items for sorting.
if isinstance(value, basestring):
items.append((to_utf8_if_string(key), to_utf8(value)))
else:
try:
value = list(value)
except TypeError, e:
assert 'is not iterable' in str(e)
items.append((to_utf8_if_string(key), to_utf8_if_string(value)))
try:
if isinstance(value, basestring):
items.append((to_utf8_if_string(key), to_utf8(value)))
else:
items.extend((to_utf8_if_string(key), to_utf8_if_string(item)) for item in value)
try:
value = list(value)
except TypeError as e:
assert 'is not iterable' in str(e)
items.append((to_utf8_if_string(key), to_utf8_if_string(value)))
else:
items.extend((to_utf8_if_string(key), to_utf8_if_string(item)) for item in value)
except NameError:
if isinstance(value, str):
items.append((to_utf8_if_string(key), to_utf8(value)))
else:
try:
value = list(value)
except TypeError as e:
assert 'is not iterable' in str(e)
items.append((to_utf8_if_string(key), to_utf8_if_string(value)))
else:
items.extend((to_utf8_if_string(key), to_utf8_if_string(item)) for item in value)

# Include any query string parameters from the provided URL
query = urlparse.urlparse(self.url)[4]
Expand All @@ -475,6 +536,10 @@ def get_normalized_parameters(self):

items.sort()
encoded_str = urllib.urlencode(items, True)
try:
encoded_str = urllib.urlencode(items)
except AttributeError:
encoded_str = urlparse.urlencode(items)
# Encode signature parameters per Oauth Core 1.0 protocol
# spec draft 7, section 3.6
# (http://tools.ietf.org/html/draft-hammer-oauth-07#section-3.6)
Expand All @@ -490,7 +555,7 @@ def sign_request(self, signature_method, consumer, token):
# section 4.1.1 "OAuth Consumers MUST NOT include an
# oauth_body_hash parameter on requests with form-encoded
# request bodies."
self['oauth_body_hash'] = base64.b64encode(sha(self.body).digest())
self['oauth_body_hash'] = base64.b64encode(sha(self.body.encode("utf-8")).digest())

if 'oauth_consumer_key' not in self:
self['oauth_consumer_key'] = consumer.key
Expand Down Expand Up @@ -605,7 +670,10 @@ def _split_header(header):
# Split key-value.
param_parts = param.split('=', 1)
# Remove quotes and unescape the value.
params[param_parts[0]] = urllib.unquote(param_parts[1].strip('\"'))
try:
params[param_parts[0]] = urllib.unquote(param_parts[1].strip('\"'))
except AttributeError:
params[param_parts[0]] = urlparse.unquote(param_parts[1].strip('\"'))
return params

@staticmethod
Expand Down Expand Up @@ -667,13 +735,19 @@ def request(self, uri, method="GET", body='', headers=None,
parameters=parameters, body=body, is_form_encoded=is_form_encoded)

req.sign_request(self.method, self.consumer, self.token)

schema, rest = urllib.splittype(uri)

try:
schema, rest = urllib.splittype(uri)
except AttributeError:
schema, rest = urlparse.splittype(uri)
if rest.startswith('//'):
hierpart = '//'
else:
hierpart = ''
host, rest = urllib.splithost(rest)
try:
host, rest = urllib.splithost(rest)
except AttributeError:
host, rest = urlparse.splithost(rest)

realm = schema + ':' + hierpart + host

Expand Down Expand Up @@ -844,7 +918,7 @@ def sign(self, request, consumer, token):
"""Builds the base signature string."""
key, raw = self.signing_base(request, consumer, token)

hashed = hmac.new(key, raw, sha)
hashed = hmac.new(key.encode('utf-8'), raw.encode('utf-8'), sha)

# Calculate the digest base 64.
return binascii.b2a_base64(hashed.digest())[:-1]
Expand Down
5 changes: 4 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
#!/usr/bin/env python

from setuptools import setup
from __future__ import print_function
import os, re

PKG='oauth2'
Expand All @@ -15,7 +17,7 @@
if mo:
mverstr = mo.group(1)
else:
print "unable to find version in %s" % (VERSIONFILE,)
print ("unable to find version in %s") % (VERSIONFILE,)
raise RuntimeError("if %s.py exists, it must be well-formed" % (VERSIONFILE,))
AVSRE = r"^auto_build_num *= *['\"]([^'\"]*)['\"]"
mo = re.search(AVSRE, verstrline, re.M)
Expand All @@ -36,5 +38,6 @@
license = "MIT License",
keywords="oauth",
zip_safe = True,
use_2to3=True,
test_suite="tests",
tests_require=['coverage', 'mock'])
Loading