Skip to content

Commit cd27e53

Browse files
committed
Fix python code in contrib
1 parent d8ddd3b commit cd27e53

File tree

5 files changed

+19
-17
lines changed

5 files changed

+19
-17
lines changed

CMakeLists.txt

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,14 @@ include(external/snappy) # download snappy
201201
include(external/snappystream)
202202
include(external/threadpool)
203203

204+
if(WITH_GPU)
205+
include(cuda)
206+
include(tensorrt)
207+
include(external/anakin)
208+
else()
209+
set(WITH_ANAKIN OFF CACHE STRING "Anakin is valid only when GPU is set." FORCE)
210+
endif()
211+
204212
include(cudnn) # set cudnn libraries, must before configure
205213
include(cupti)
206214
include(configure) # add paddle env configuration
@@ -229,14 +237,6 @@ set(EXTERNAL_LIBS
229237
${PYTHON_LIBRARIES}
230238
)
231239

232-
if(WITH_GPU)
233-
include(cuda)
234-
include(tensorrt)
235-
include(external/anakin)
236-
else()
237-
set(WITH_ANAKIN OFF CACHE STRING "Anakin is valid only when GPU is set." FORCE)
238-
endif()
239-
240240
if(WITH_AMD_GPU)
241241
find_package(HIP)
242242
include(hip)

cmake/cudnn.cmake

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ list(APPEND CUDNN_CHECK_LIBRARY_DIRS
2121
${CUDNN_ROOT}/lib64
2222
${CUDNN_ROOT}/lib
2323
${CUDNN_ROOT}/lib/${TARGET_ARCH}-linux-gnu
24+
${CUDNN_ROOT}/local/cuda-${CUDA_VERSION}/targets/${TARGET_ARCH}-linux/lib/
2425
$ENV{CUDNN_ROOT}
2526
$ENV{CUDNN_ROOT}/lib64
2627
$ENV{CUDNN_ROOT}/lib

python/paddle/fluid/contrib/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
import decoder
16-
from decoder import *
15+
from . import decoder
16+
from .decoder import *
1717

1818
__all__ = decoder.__all__

python/paddle/fluid/contrib/decoder/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
import beam_search_decoder
16-
from beam_search_decoder import *
15+
from . import beam_search_decoder
16+
from .beam_search_decoder import *
1717

1818
__all__ = beam_search_decoder.__all__

python/paddle/fluid/contrib/decoder/beam_search_decoder.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -191,7 +191,7 @@ def __init__(self, inputs, states, out_state, name=None):
191191
self._helper = LayerHelper('state_cell', name=name)
192192
self._cur_states = {}
193193
self._state_names = []
194-
for state_name, state in states.items():
194+
for state_name, state in six.iteritems(states):
195195
if not isinstance(state, InitState):
196196
raise ValueError('state must be an InitState object.')
197197
self._cur_states[state_name] = state
@@ -346,7 +346,7 @@ def compute_state(self, inputs):
346346
if self._in_decoder and not self._switched_decoder:
347347
self._switch_decoder()
348348

349-
for input_name, input_value in inputs.items():
349+
for input_name, input_value in six.iteritems(inputs):
350350
if input_name not in self._inputs:
351351
raise ValueError('Unknown input %s. '
352352
'Please make sure %s in input '
@@ -361,7 +361,7 @@ def update_states(self):
361361
if self._in_decoder and not self._switched_decoder:
362362
self._switched_decoder()
363363

364-
for state_name, decoder_state in self._states_holder.items():
364+
for state_name, decoder_state in six.iteritems(self._states_holder):
365365
if id(self._cur_decoder_obj) not in decoder_state:
366366
raise ValueError('Unknown decoder object, please make sure '
367367
'switch_decoder been invoked.')
@@ -671,7 +671,7 @@ def decode(self):
671671
feed_dict = {}
672672
update_dict = {}
673673

674-
for init_var_name, init_var in self._input_var_dict.items():
674+
for init_var_name, init_var in six.iteritems(self._input_var_dict):
675675
if init_var_name not in self.state_cell._inputs:
676676
raise ValueError('Variable ' + init_var_name +
677677
' not found in StateCell!\n')
@@ -721,7 +721,8 @@ def decode(self):
721721
self.state_cell.update_states()
722722
self.update_array(prev_ids, selected_ids)
723723
self.update_array(prev_scores, selected_scores)
724-
for update_name, var_to_update in update_dict.items():
724+
for update_name, var_to_update in six.iteritems(
725+
update_dict):
725726
self.update_array(var_to_update, feed_dict[update_name])
726727

727728
def read_array(self, init, is_ids=False, is_scores=False):

0 commit comments

Comments
 (0)