88#include < iostream>
99#include < algorithm>
1010#include < stdexcept>
11+ #include < array>
1112
1213namespace facebook {
1314namespace graphql {
@@ -48,7 +49,7 @@ const peg::ast_node& Fragment::getSelection() const
4849 return _selection;
4950}
5051
51- uint8_t Base64::verifyFromBase64 (unsigned char ch)
52+ uint8_t Base64::verifyFromBase64 (char ch)
5253{
5354 uint8_t result = fromBase64 (ch);
5455
@@ -60,70 +61,94 @@ uint8_t Base64::verifyFromBase64(unsigned char ch)
6061 return result;
6162}
6263
63- std::vector<unsigned char > Base64::fromBase64 (const char * encoded, size_t count)
64+ std::vector<uint8_t > Base64::fromBase64 (const char * encoded, size_t count)
6465{
65- std::vector<unsigned char > result;
66+ std::vector<uint8_t > result;
6667
6768 if (!count)
6869 {
6970 return result;
7071 }
7172
72- result.reserve (count * 3 / 4 );
73- while (encoded[0 ] && encoded[1 ])
73+ result.reserve ((count + (count % 4 )) * 3 / 4 );
74+
75+ // First decode all of the full unpadded segments 24 bits at a time
76+ while (count >= 4
77+ && encoded[3 ] != padding)
7478 {
75- uint16_t buffer = static_cast <uint16_t >(verifyFromBase64 (*encoded++)) << 10 ;
79+ const uint32_t segment = (static_cast <uint32_t >(verifyFromBase64 (encoded[0 ])) << 18 )
80+ | (static_cast <uint32_t >(verifyFromBase64 (encoded[1 ])) << 12 )
81+ | (static_cast <uint32_t >(verifyFromBase64 (encoded[2 ])) << 6 )
82+ | static_cast <uint32_t >(verifyFromBase64 (encoded[3 ]));
83+
84+ result.emplace_back (static_cast <uint8_t >((segment & 0xFF0000 ) >> 16 ));
85+ result.emplace_back (static_cast <uint8_t >((segment & 0xFF00 ) >> 8 ));
86+ result.emplace_back (static_cast <uint8_t >(segment & 0xFF ));
87+
88+ encoded += 4 ;
89+ count -= 4 ;
90+ }
7691
77- buffer |= static_cast <uint16_t >(verifyFromBase64 (*encoded++)) << 4 ;
78- result.push_back (static_cast <unsigned char >((buffer & 0xFF00 ) >> 8 ));
79- buffer = (buffer & 0xFF ) << 8 ;
92+ // Get any leftover partial segment with 2 or 3 non-padding characters
93+ if (count > 1 )
94+ {
95+ const bool triplet = (count > 2 && padding != encoded[2 ]);
96+ const uint8_t tail = (triplet ? verifyFromBase64 (encoded[2 ]) : 0 );
97+ const uint16_t segment = (static_cast <uint16_t >(verifyFromBase64 (encoded[0 ])) << 10 )
98+ | (static_cast <uint16_t >(verifyFromBase64 (encoded[1 ])) << 4 )
99+ | (static_cast <uint16_t >(tail) >> 2 );
80100
81- if (!*encoded || ' = ' == *encoded )
101+ if (triplet )
82102 {
83- if (0 != buffer
84- || (*encoded && (*++encoded != ' =' || *++encoded)))
103+ if (tail & 0x3 )
85104 {
86105 throw schema_exception ({ " invalid padding at the end of a base64 encoded string" });
87106 }
88107
89- break ;
90- }
91-
92- buffer |= static_cast <uint16_t >(verifyFromBase64 (*encoded++)) << 6 ;
93- result.push_back (static_cast <unsigned char >((buffer & 0xFF00 ) >> 8 ));
94- buffer &= 0xFF ;
108+ result.emplace_back (static_cast <uint8_t >((segment & 0xFF00 ) >> 8 ));
109+ result.emplace_back (static_cast <uint8_t >(segment & 0xFF ));
95110
96- if (!*encoded || ' =' == *encoded)
111+ encoded += 3 ;
112+ count -= 3 ;
113+ }
114+ else
97115 {
98- if (0 != buffer
99- || (*encoded && *++encoded))
116+ if (segment & 0xFF )
100117 {
101118 throw schema_exception ({ " invalid padding at the end of a base64 encoded string" });
102119 }
103120
104- break ;
121+ result.emplace_back (static_cast <uint8_t >((segment & 0xFF00 ) >> 8 ));
122+
123+ encoded += 2 ;
124+ count -= 2 ;
105125 }
126+ }
106127
107- buffer |= static_cast <uint16_t >(verifyFromBase64 (*encoded++));
108- result.push_back (static_cast <unsigned char >(buffer & 0xFF ));
128+ // Make sure anything that's left is 0 - 2 characters of padding
129+ if ((count > 0 && padding != encoded[0 ])
130+ || (count > 1 && padding != encoded[1 ])
131+ || count > 2 )
132+ {
133+ throw schema_exception ({ " invalid padding at the end of a base64 encoded string" });
109134 }
110135
111136 return result;
112137}
113138
114- unsigned char Base64::verifyToBase64 (uint8_t i)
139+ char Base64::verifyToBase64 (uint8_t i)
115140{
116141 unsigned char result = toBase64 (i);
117142
118- if (result == ' = ' )
143+ if (result == padding )
119144 {
120145 throw std::logic_error (" invalid 6-bit value" );
121146 }
122147
123148 return result;
124149}
125150
126- std::string Base64::toBase64 (const std::vector<unsigned char >& bytes)
151+ std::string Base64::toBase64 (const std::vector<uint8_t >& bytes)
127152{
128153 std::string result;
129154
@@ -132,38 +157,43 @@ std::string Base64::toBase64(const std::vector<unsigned char>& bytes)
132157 return result;
133158 }
134159
135- auto itr = bytes.cbegin ();
136- const auto itrEnd = bytes.cend ();
137- const size_t count = bytes.size ();
160+ size_t count = bytes.size ();
161+ const uint8_t * data = bytes.data ();
138162
139163 result.reserve ((count + (count % 3 )) * 4 / 3 );
140- while (itr != itrEnd)
141- {
142- uint16_t buffer = static_cast <uint8_t >(*itr++) << 8 ;
143-
144- result.push_back (verifyToBase64 ((buffer & 0xFC00 ) >> 10 ));
145-
146- if (itr == itrEnd)
147- {
148- result.push_back (verifyToBase64 ((buffer & 0x03F0 ) >> 4 ));
149- result.append (" ==" );
150- break ;
151- }
152164
153- buffer |= static_cast <uint8_t >(*itr++);
154- result.push_back (verifyToBase64 ((buffer & 0x03F0 ) >> 4 ));
155- buffer = buffer << 8 ;
165+ // First encode all of the full unpadded segments 24 bits at a time
166+ while (count >= 3 )
167+ {
168+ const uint32_t segment = (static_cast <uint32_t >(data[0 ]) << 16 )
169+ | (static_cast <uint32_t >(data[1 ]) << 8 )
170+ | static_cast <uint32_t >(data[2 ]);
171+
172+ result.append ({
173+ verifyToBase64 ((segment & 0xFC0000 ) >> 18 ),
174+ verifyToBase64 ((segment & 0x3F000 ) >> 12 ),
175+ verifyToBase64 ((segment & 0xFC0 ) >> 6 ),
176+ verifyToBase64 (segment & 0x3F )
177+ });
156178
157- if (itr == itrEnd)
158- {
159- result.push_back (verifyToBase64 ((buffer & 0x0FC0 ) >> 6 ));
160- result.push_back (' =' );
161- break ;
162- }
179+ data += 3 ;
180+ count -= 3 ;
181+ }
163182
164- buffer |= static_cast <uint8_t >(*itr++);
165- result.push_back (verifyToBase64 ((buffer & 0x0FC0 ) >> 6 ));
166- result.push_back (verifyToBase64 (buffer & 0x3F ));
183+ // Get any leftover partial segment with 1 or 2 bytes
184+ if (count > 0 )
185+ {
186+ const bool pair = (count > 1 );
187+ const uint16_t segment = (static_cast <uint16_t >(data[0 ]) << 8 )
188+ | (pair ? static_cast <uint16_t >(data[1 ]) : 0 );
189+ const std::array<char , 4 > remainder {
190+ verifyToBase64 ((segment & 0xFC00 ) >> 10 ),
191+ verifyToBase64 ((segment & 0x3F0 ) >> 4 ),
192+ (pair ? verifyToBase64 ((segment & 0xF ) << 2 ) : padding),
193+ padding
194+ };
195+
196+ result.append (remainder.data (), remainder.size ());
167197 }
168198
169199 return result;
@@ -229,7 +259,7 @@ rapidjson::Document ModifiedArgument<rapidjson::Document>::convert(const rapidjs
229259}
230260
231261template <>
232- std::vector<unsigned char > ModifiedArgument<std::vector<unsigned char >>::convert(const rapidjson::Value& value)
262+ std::vector<uint8_t > ModifiedArgument<std::vector<uint8_t >>::convert(const rapidjson::Value& value)
233263{
234264 if (!value.IsString ())
235265 {
@@ -282,7 +312,7 @@ rapidjson::Document ModifiedResult<rapidjson::Document>::convert(rapidjson::Docu
282312}
283313
284314template <>
285- rapidjson::Document ModifiedResult<std::vector<unsigned char >>::convert(std::vector<unsigned char >&& result, ResolverParams&&)
315+ rapidjson::Document ModifiedResult<std::vector<uint8_t >>::convert(std::vector<uint8_t >&& result, ResolverParams&&)
286316{
287317 rapidjson::Document document (rapidjson::Type::kStringType );
288318
0 commit comments